├── .swift-version ├── .dockerignore ├── Sources ├── CXGBoost │ ├── module.modulemap │ └── shim.h └── XGBoost │ ├── Logging.swift │ ├── String.swift │ ├── Dictionary.swift │ ├── Safe.swift │ ├── TensorFlow.swift │ ├── Feature.swift │ ├── Parameter.swift │ ├── Enums.swift │ ├── Plot.swift │ ├── Python.swift │ ├── Array.swift │ ├── CrossValidation.swift │ ├── Train.swift │ ├── DMatrix.swift │ └── Booster.swift ├── Tests ├── LinuxMain.swift └── XGBoostTests │ ├── PythonTests.swift │ ├── XCTestManifests.swift │ ├── TensorFlowTests.swift │ ├── PlotTests.swift │ ├── ExampleTests.swift │ ├── Utilities.swift │ ├── ArrayTests.swift │ ├── XGBoostDocsTests.swift │ ├── CrossValidationTests.swift │ ├── TrainTests.swift │ ├── DMatrixTests.swift │ └── BoosterTests.swift ├── .gitignore ├── .swiftlint.yml ├── install.sh ├── docker-compose.yml ├── .github └── workflows │ ├── ubuntu_test.yml │ ├── lint.yml │ └── macos_test.yml ├── Examples ├── Data │ ├── data.csv │ ├── data.svm.txt │ └── veterans_lung_cancer.csv ├── ExternalMemory │ └── main.swift ├── PredictLeafIndices │ └── main.swift ├── EvalsResult │ └── main.swift ├── PredictFirstNTree │ └── main.swift ├── BoostFromPrediction │ └── main.swift ├── GeneralizedLinearModel │ └── main.swift ├── CrossValidation │ └── main.swift ├── CustomObjective │ └── main.swift └── AftSurvival │ └── main.swift ├── doc ├── index.rst └── swift_intro.rst ├── .travis.yml ├── Makefile ├── Dockerfile ├── Package.resolved ├── Package.swift ├── README.md └── LICENSE /.swift-version: -------------------------------------------------------------------------------- 1 | 5.2.4 -------------------------------------------------------------------------------- /.dockerignore: -------------------------------------------------------------------------------- 1 | .build 2 | .github 3 | .history 4 | .idea 5 | .vscode 6 | .DS_Store 7 | *.xgboost 8 | -------------------------------------------------------------------------------- /Sources/CXGBoost/module.modulemap: 
-------------------------------------------------------------------------------- 1 | module CXGBoost [system] { 2 | header "shim.h" 3 | link "xgboost" 4 | export * 5 | } -------------------------------------------------------------------------------- /Sources/CXGBoost/shim.h: -------------------------------------------------------------------------------- 1 | #ifndef cxgboost_shim_h 2 | #define cxgboost_shim_h 3 | 4 | #include 5 | #include 6 | 7 | #endif -------------------------------------------------------------------------------- /Tests/LinuxMain.swift: -------------------------------------------------------------------------------- 1 | import XCTest 2 | 3 | @testable import XGBoostTests 4 | 5 | var tests = [XCTestCaseEntry]() 6 | 7 | tests += XGBoostAllTests() 8 | 9 | XCTMain(tests) 10 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .build 2 | .build-mac 3 | .swiftpm 4 | .history 5 | .idea 6 | .code 7 | .vscode 8 | .DS_Store 9 | Documentation 10 | *.xgboost 11 | *.buffer 12 | *.png 13 | *.svg 14 | venv 15 | .metals 16 | -------------------------------------------------------------------------------- /Sources/XGBoost/Logging.swift: -------------------------------------------------------------------------------- 1 | import Logging 2 | 3 | /// Logger instance used for messages. 4 | let logger = Logger(label: "swiftxgboost") 5 | 6 | /// - Parameters: Message to log. 7 | func log(_ message: String) { 8 | // TODO: Rabit logging. 
9 | logger.info("\(message)") 10 | } 11 | -------------------------------------------------------------------------------- /.swiftlint.yml: -------------------------------------------------------------------------------- 1 | disabled_rules: 2 | - trailing_comma 3 | - identifier_name 4 | - todo 5 | - force_try 6 | - large_tuple 7 | - force_cast 8 | - type_name 9 | - type_body_length 10 | - line_length 11 | - file_length 12 | - function_body_length 13 | - cyclomatic_complexity 14 | -------------------------------------------------------------------------------- /install.sh: -------------------------------------------------------------------------------- 1 | tmp_dir=$(mktemp -d -t xgboost-XXXXXX) 2 | echo "Will work in $tmp_dir." 3 | 4 | cd $tmp_dir 5 | git clone https://github.com/dmlc/xgboost 6 | cd xgboost 7 | git checkout tags/v1.1.1 8 | git submodule update --init --recursive 9 | mkdir build 10 | cd build 11 | cmake .. 12 | make 13 | make install 14 | ldconfig 15 | 16 | echo "Success!" 17 | -------------------------------------------------------------------------------- /docker-compose.yml: -------------------------------------------------------------------------------- 1 | version: '3.7' 2 | services: 3 | swiftxgboost: 4 | build: 5 | context: . 6 | target: build 7 | image: swiftxgboost 8 | volumes: 9 | - "./:/app" 10 | 11 | test: 12 | build: 13 | context: . 
14 | target: build 15 | image: swiftxgboost 16 | volumes: 17 | - "./:/app" 18 | command: swift test 19 | -------------------------------------------------------------------------------- /.github/workflows/ubuntu_test.yml: -------------------------------------------------------------------------------- 1 | name: Ubuntu 2 | 3 | on: 4 | push: 5 | branches: [ master ] 6 | pull_request: 7 | branches: [ master ] 8 | 9 | jobs: 10 | UbuntuTests: 11 | runs-on: ubuntu-latest 12 | steps: 13 | - uses: actions/checkout@v1 14 | - name: Install dependencies 15 | run: sudo apt-get install -y make 16 | - name: Run 17 | run: make test 18 | -------------------------------------------------------------------------------- /Examples/Data/data.csv: -------------------------------------------------------------------------------- 1 | 0,69.0,60.0,7.0,0,0,0,1,1,0,1,0,0 2 | 1,411.0,64.0,70.0,5.0,0,0,0,1,0,1,1,0 3 | 0,228.0,38.0,60.0,3.0,0,0,0,1,1,0,1,0 4 | 0,126.0,63.0,60.0,9.0,0,0,0,1,0,1,1,0 5 | 0,118.0,65.0,70.0,11.0,0,0,0,1,0,1,1,0 6 | 1,10.0,49.0,20.0,5.0,0,0,0,1,1,0,1,0 7 | 1,82.0,69.0,40.0,10.0,0,0,0,1,0,1,1,0 8 | 1,110.0,68.0,80.0,29.0,0,0,0,1,1,0,1,0 9 | 0,314.0,43.0,50.0,18.0,0,0,0,1,1,0,1,0 10 | 0,inf,70.0,70.0,6.0,0,0,0,1,1,0,1,0 11 | 0,42.0,81.0,60.0,4.0,0,0,0,1,1,0,1,0 12 | -------------------------------------------------------------------------------- /doc/index.rst: -------------------------------------------------------------------------------- 1 | ###################### 2 | XGBoost Swift Package 3 | ###################### 4 | This page contains links to all the related documents on the swift package. 5 | The package itself is hosted at `github `_. 6 | 7 | ******** 8 | Contents 9 | ******** 10 | 11 | .. 
toctree:: 12 | swift_intro 13 | Swift API documentation 14 | Swift examples 15 | -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | language: generic 2 | os: osx 3 | osx_image: xcode11.5 4 | addons: 5 | homebrew: 6 | update: true 7 | packages: 8 | - make 9 | - xgboost 10 | - sourcekitten 11 | before_install: 12 | - gem install jazzy 13 | after_success: 14 | - brew info xgboost 15 | - gmake documentation 16 | deploy: 17 | provider: pages 18 | skip_cleanup: true 19 | github_token: $GH_TOKEN 20 | local_dir: Documentation 21 | keep_history: true 22 | branches: 23 | only: 24 | - master 25 | -------------------------------------------------------------------------------- /Tests/XGBoostTests/PythonTests.swift: -------------------------------------------------------------------------------- 1 | import XCTest 2 | 3 | @testable import XGBoost 4 | 5 | final class PythonTests: XCTestCase { 6 | func testMakeNumpyArrayWithShape() throws { 7 | let values = [1.0, 2.0, 3.0, 4.0] 8 | let shape = Shape(2, 2) 9 | let npValues = values.makeNumpyArray(shape: shape) 10 | XCTAssertEqual(try npValues.dataShape(), shape) 11 | } 12 | 13 | static var allTests = [ 14 | ("testMakeNumpyArrayWithShape", testMakeNumpyArrayWithShape), 15 | ] 16 | } 17 | -------------------------------------------------------------------------------- /Sources/XGBoost/String.swift: -------------------------------------------------------------------------------- 1 | import Foundation 2 | 3 | public extension String { 4 | /// String in C-compatible format, passable to the C-functions. 5 | var cCompatible: UnsafePointer? { 6 | withCString { (cString) -> UnsafePointer? in 7 | let length = strlen(cString) + 1 8 | let name: UnsafeMutablePointer? 
= .allocate(capacity: length) 9 | name!.initialize(from: cString, count: length) 10 | return UnsafePointer(name) 11 | } 12 | } 13 | } 14 | -------------------------------------------------------------------------------- /Sources/XGBoost/Dictionary.swift: -------------------------------------------------------------------------------- 1 | public extension Dictionary { 2 | /// - Parameter key: Key to retrieve from self. 3 | /// - Parameter or: Value that will be set for `key` if `self[key]` is nil. 4 | /// - Returns: Value for `key` or `or` if `key` is not set. 5 | subscript(key: Key, or def: Value) -> Value { 6 | mutating get { 7 | self[key] ?? { 8 | self[key] = def 9 | return def 10 | }() 11 | } 12 | set { self[key] = newValue } 13 | } 14 | } 15 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | .PHONY: documentation 2 | 3 | test: 4 | docker-compose run test 5 | 6 | format: 7 | docker-compose run swiftxgboost swiftformat . 
--exclude Sources/XGBoost/Python.swift,Tests/XGBoostTests/ArrayTests.swift,.build,.history 8 | 9 | documentation: 10 | @sourcekitten doc \ 11 | --spm \ 12 | --module-name XGBoost \ 13 | > SK_XGBoost.json 14 | @jazzy \ 15 | --output Documentation \ 16 | --github_url https://github.com/kongzii/SwiftXGBoost \ 17 | --min-acl public \ 18 | --sourcekitten-sourcefile SK_XGBoost.json 19 | @rm SK_*.json 20 | -------------------------------------------------------------------------------- /.github/workflows/lint.yml: -------------------------------------------------------------------------------- 1 | name: LintTest 2 | 3 | on: 4 | push: 5 | branches: [ master ] 6 | pull_request: 7 | branches: [ master ] 8 | 9 | jobs: 10 | SwiftLint: 11 | runs-on: ubuntu-latest 12 | steps: 13 | - uses: actions/checkout@v1 14 | - name: SwiftLint 15 | uses: norio-nomura/action-swiftlint@3.1.0 16 | SwiftFormat: 17 | runs-on: ubuntu-latest 18 | steps: 19 | - uses: actions/checkout@v1 20 | - name: Run formatting 21 | run: make format 22 | - name: Run git diff 23 | run: git diff --exit-code 24 | -------------------------------------------------------------------------------- /Sources/XGBoost/Safe.swift: -------------------------------------------------------------------------------- 1 | import CXGBoost 2 | import Foundation 3 | 4 | enum XGBoostError: Error { 5 | /// Runtime error caused when output of C-API functions != 0. 6 | case runtimeError(String) 7 | } 8 | 9 | enum ValueError: Error { 10 | /// Runtime error caused on incompatible inputs. 11 | case runtimeError(String) 12 | } 13 | 14 | /// Helper function to handle C-API calls. 15 | /// 16 | /// - Parameter call: Closure returning output of XGBoost C-API function. 
17 | func safe(call: () throws -> Int32) throws { 18 | if try call() != 0 { 19 | throw XGBoostError.runtimeError(String(cString: XGBGetLastError())) 20 | } 21 | } 22 | -------------------------------------------------------------------------------- /Tests/XGBoostTests/XCTestManifests.swift: -------------------------------------------------------------------------------- 1 | import XCTest 2 | 3 | #if !os(macOS) 4 | public func XGBoostAllTests() -> [XCTestCaseEntry] { 5 | [ 6 | testCase(ExampleTests.allTests), 7 | testCase(DMatrixTests.allTests), 8 | testCase(BoosterTests.allTests), 9 | testCase(PlotTests.allTests), 10 | testCase(TensorFlowTests.allTests), 11 | testCase(XGBoostDocsTests.allTests), 12 | testCase(TrainTests.allTests), 13 | testCase(CrossValidationTests.allTests), 14 | testCase(ArrayTests.allTests), 15 | testCase(PythonTests.allTests), 16 | ] 17 | } 18 | #endif 19 | -------------------------------------------------------------------------------- /Sources/XGBoost/TensorFlow.swift: -------------------------------------------------------------------------------- 1 | #if canImport(TensorFlow) 2 | 3 | import TensorFlow 4 | 5 | extension Tensor: FloatData where Scalar == Float { 6 | /// Conformance for FloatData. 7 | public func data() throws -> [Float] { 8 | scalars 9 | } 10 | } 11 | 12 | extension Tensor: UInt32Data where Scalar == UInt32 { 13 | /// Conformance for UInt32Data. 14 | public func data() throws -> [UInt32] { 15 | scalars 16 | } 17 | } 18 | 19 | extension Tensor: ShapeData { 20 | /// Conformance for ShapeData. 21 | public func dataShape() throws -> Shape { 22 | Shape(shape) 23 | } 24 | } 25 | 26 | #endif 27 | -------------------------------------------------------------------------------- /Sources/XGBoost/Feature.swift: -------------------------------------------------------------------------------- 1 | /// Type of feature. 
2 | public enum FeatureType: String { 3 | case quantitative = "q" 4 | case indicator = "i" 5 | } 6 | 7 | /// Struct holding information about feature. 8 | public struct Feature: Equatable { 9 | public let name: String 10 | public let type: FeatureType 11 | 12 | /// - Parameter name: Name of feature. 13 | /// - Parameter type: Type of feature. 14 | public init(_ name: String, _ type: FeatureType) { 15 | self.name = name 16 | self.type = type 17 | } 18 | 19 | /// - Parameter name: Name of feature. 20 | /// - Parameter type: Type of feature. 21 | public init(name: String, type: FeatureType) { 22 | self.name = name 23 | self.type = type 24 | } 25 | } 26 | -------------------------------------------------------------------------------- /Tests/XGBoostTests/TensorFlowTests.swift: -------------------------------------------------------------------------------- 1 | import XCTest 2 | 3 | #if canImport(TensorFlow) 4 | import TensorFlow 5 | #endif 6 | 7 | @testable import XGBoost 8 | 9 | final class TensorFlowTests: XCTestCase { 10 | #if canImport(TensorFlow) 11 | func testDataFromTensor() throws { 12 | let x = Tensor(shape: TensorShape([2, 3]), scalars: [1, 2, 3, 4, 5, 6]) 13 | let data = try DMatrix(name: "test", from: x) 14 | XCTAssertEqual(try data.rowCount(), 2) 15 | XCTAssertEqual(try data.columnCount(), 3) 16 | } 17 | 18 | static var allTests = [ 19 | ("testDataFromTensor", testDataFromTensor), 20 | ] 21 | #else 22 | static var allTests = [(String, (XCTestCase) -> () -> Void)]() 23 | #endif 24 | } 25 | -------------------------------------------------------------------------------- /Sources/XGBoost/Parameter.swift: -------------------------------------------------------------------------------- 1 | /// Struct representing parameter for Booster. 2 | /// See https://xgboost.readthedocs.io/en/latest/parameter.html for available parameters. 
3 | public struct Parameter: Equatable, Codable { 4 | var count: Int = 0 5 | 6 | public let name: String 7 | public let value: String 8 | 9 | /// - Parameter: Name of parameter. 10 | /// - Parameter: Value for parameter, can be anything convertible to string. 11 | public init(_ name: String, _ value: CustomStringConvertible) { 12 | self.name = name 13 | self.value = "\(value)" 14 | } 15 | 16 | /// - Parameter name: Name of parameter. 17 | /// - Parameter value: Value for parameter, can be anything convertible to string. 18 | public init(name: String, value: CustomStringConvertible) { 19 | self.name = name 20 | self.value = "\(value)" 21 | } 22 | } 23 | -------------------------------------------------------------------------------- /.github/workflows/macos_test.yml: -------------------------------------------------------------------------------- 1 | name: MacOS 2 | 3 | on: 4 | push: 5 | branches: [ master ] 6 | pull_request: 7 | branches: [ master ] 8 | 9 | jobs: 10 | CodeCoverageAndMacOSTests: 11 | runs-on: macos-latest 12 | steps: 13 | - uses: actions/checkout@v1 14 | - name: Update brew 15 | run: brew update 16 | - name: Install dependencies from brew 17 | run: brew install xgboost llvm 18 | - name: Install dependencies for python 19 | run: pip3 install xgboost coverage numpy pandas 20 | - name: Test and create CodeCoverage 21 | run: swift test --enable-code-coverage 22 | - name: Export coverage 23 | run: xcrun llvm-cov show .build/x86_64-apple-macosx/debug/SwiftXGBoostPackageTests.xctest/Contents/MacOS/SwiftXGBoostPackageTests --instr-profile .build/x86_64-apple-macosx/debug/codecov/default.profdata --ignore-filename-regex=Tests/ > coverage.txt 24 | - name: Upload CodeCoverage 25 | run: bash <(curl -s https://codecov.io/bash) 26 | -------------------------------------------------------------------------------- /Examples/ExternalMemory/main.swift: -------------------------------------------------------------------------------- 1 | // This example is based on 
https://github.com/dmlc/xgboost/blob/master/demo/guide-python/external_memory.py 2 | 3 | import Foundation 4 | import XGBoost 5 | 6 | let train = try DMatrix( 7 | name: "train", 8 | from: "Examples/Data/agaricus.txt.train", 9 | format: .libsvm, 10 | useCache: true 11 | ) 12 | let test = try DMatrix( 13 | name: "test", 14 | from: "Examples/Data/agaricus.txt.test", 15 | format: .libsvm, 16 | useCache: true 17 | ) 18 | 19 | let parameters = [ 20 | Parameter("seed", 0), 21 | Parameter("eta", 1), 22 | Parameter("max_depth", 2), 23 | Parameter("objective", "binary:logistic"), 24 | ] 25 | 26 | let booster = try Booster( 27 | with: [train, test], 28 | parameters: parameters 29 | ) 30 | 31 | try booster.train( 32 | iterations: 2, 33 | trainingData: train, 34 | evaluationData: [train, test] 35 | ) { _, iteration, evaluation, _ in 36 | print(iteration, ":", evaluation!) 37 | return .next 38 | } 39 | -------------------------------------------------------------------------------- /Examples/Data/data.svm.txt: -------------------------------------------------------------------------------- 1 | 0 1:1 9:1 19:1 21:1 24:1 34:1 36:1 39:1 42:1 53:1 56:1 65:1 69:1 77:1 86:1 88:1 92:1 95:1 102:1 106:1 117:1 122:1 2 | 1 3:1 9:1 19:1 21:1 30:1 34:1 36:1 40:1 41:1 53:1 58:1 65:1 69:1 77:1 86:1 88:1 92:1 95:1 102:1 106:1 118:1 124:1 3 | 0 1:1 9:1 20:1 21:1 24:1 34:1 36:1 39:1 41:1 53:1 56:1 65:1 69:1 77:1 86:1 88:1 92:1 95:1 102:1 106:1 117:1 122:1 4 | 0 3:1 9:1 19:1 21:1 24:1 34:1 36:1 39:1 51:1 53:1 56:1 65:1 69:1 77:1 86:1 88:1 92:1 95:1 102:1 106:1 116:1 122:1 5 | 0 4:1 7:1 11:1 22:1 29:1 34:1 36:1 40:1 41:1 53:1 58:1 65:1 69:1 77:1 86:1 88:1 92:1 95:1 102:1 105:1 119:1 124:1 6 | 0 3:1 10:1 20:1 21:1 23:1 34:1 37:1 40:1 42:1 54:1 55:1 65:1 69:1 77:1 86:1 88:1 92:1 95:1 102:1 106:1 118:1 126:1 7 | 1 3:1 9:1 11:1 21:1 30:1 34:1 36:1 40:1 51:1 53:1 58:1 65:1 69:1 77:1 86:1 88:1 92:1 95:1 102:1 106:1 117:1 124:1 8 | 0 1:1 9:1 20:1 21:1 23:1 34:1 36:1 39:1 42:1 53:1 56:1 65:1 69:1 77:1 86:1 
88:1 92:1 95:1 102:1 106:1 117:1 120:1 9 | 1 3:1 9:1 19:1 21:1 30:1 34:1 36:1 40:1 48:1 53:1 58:1 65:1 69:1 77:1 86:1 88:1 92:1 95:1 102:1 106:1 118:1 120:1 10 | -------------------------------------------------------------------------------- /Examples/PredictLeafIndices/main.swift: -------------------------------------------------------------------------------- 1 | // This example is based on https://github.com/dmlc/xgboost/blob/master/demo/guide-python/predict_leaf_indices.py 2 | 3 | import Foundation 4 | import XGBoost 5 | 6 | let train = try DMatrix( 7 | name: "train", 8 | from: "Examples/Data/agaricus.txt.train", 9 | format: .libsvm 10 | ) 11 | let test = try DMatrix( 12 | name: "test", 13 | from: "Examples/Data/agaricus.txt.test", 14 | format: .libsvm 15 | ) 16 | 17 | let parameters = [ 18 | Parameter("seed", 0), 19 | Parameter("eta", 1), 20 | Parameter("max_depth", 2), 21 | Parameter("objective", "binary:logistic"), 22 | ] 23 | 24 | let booster = try Booster( 25 | with: [train, test], 26 | parameters: parameters 27 | ) 28 | 29 | try booster.train( 30 | iterations: 3, 31 | trainingData: train 32 | ) 33 | 34 | let leafIndex1 = try booster.predict( 35 | from: test, 36 | treeLimit: 2, 37 | predictionLeaf: true 38 | ).makeNumpyArray() 39 | let leafIndex2 = try booster.predict( 40 | from: test, 41 | predictionLeaf: true 42 | ).makeNumpyArray() 43 | 44 | print(leafIndex1) 45 | print(leafIndex2) 46 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | FROM swift:5.2 AS build 2 | 3 | ARG DEBIAN_FRONTEND=noninteractive 4 | 5 | RUN apt-get update \ 6 | && apt-get upgrade -y \ 7 | && apt-get install -y \ 8 | wget curl clang libblocksruntime-dev libxml2-dev \ 9 | libxml2 zlib1g-dev git python3.8 python3.8-dev libpython3.8-dev python3-pip \ 10 | pkg-config libatomic1 netcat-openbsd libfreetype6-dev 11 | 12 | RUN wget 
https://github.com/Kitware/CMake/releases/download/v3.15.2/cmake-3.15.2.tar.gz \ 13 | && apt-get remove --purge -y cmake \ 14 | && tar -zxvf cmake-3.15.2.tar.gz \ 15 | && rm cmake-3.15.2.tar.gz \ 16 | && cd cmake-3.15.2 \ 17 | && ./bootstrap \ 18 | && make \ 19 | && make install \ 20 | && cmake --version 21 | 22 | RUN git clone https://github.com/yonaskolb/Mint.git && cd Mint && swift run mint install yonaskolb/mint 23 | 24 | RUN mint install nicklockwood/SwiftFormat && mint install realm/SwiftLint 25 | 26 | RUN python3.8 -m pip install xgboost numpy pandas 27 | 28 | COPY install.sh install.sh 29 | RUN chmod +x install.sh && ./install.sh 30 | 31 | WORKDIR /app 32 | 33 | FROM build AS app 34 | 35 | COPY . ./ 36 | 37 | CMD ["swift", "test"] 38 | -------------------------------------------------------------------------------- /Tests/XGBoostTests/PlotTests.swift: -------------------------------------------------------------------------------- 1 | import XCTest 2 | 3 | @testable import XGBoost 4 | 5 | final class PlotTests: XCTestCase { 6 | func testSaveImportanceGraph() throws { 7 | let randomArray = (0 ..< 50).map { _ in Float.random(in: 0 ..< 2) } 8 | let label = (0 ..< 10).map { _ in Float([0, 1].randomElement()!) 
} 9 | let features = (0 ..< 5).map { Feature(name: "Name-\($0)", type: .quantitative) } 10 | let data = try DMatrix( 11 | name: "data", 12 | from: randomArray, 13 | shape: Shape(10, 5), 14 | features: features, 15 | label: label, 16 | threads: 1 17 | ) 18 | let booster = try Booster(with: [data]) 19 | try booster.train( 20 | iterations: 10, 21 | trainingData: data 22 | ) 23 | 24 | let graphFile = temporaryFile() 25 | let featureMapFile = temporaryFile() 26 | 27 | try data.saveFeatureMap(to: featureMapFile) 28 | try booster.saveImportanceGraph(to: graphFile, featureMap: featureMapFile) 29 | 30 | XCTAssertTrue(FileManager.default.fileExists(atPath: featureMapFile)) 31 | XCTAssertTrue(FileManager.default.fileExists(atPath: graphFile + ".svg")) 32 | } 33 | 34 | static var allTests = [ 35 | ("testSaveImportanceGraph", testSaveImportanceGraph), 36 | ] 37 | } 38 | -------------------------------------------------------------------------------- /Examples/EvalsResult/main.swift: -------------------------------------------------------------------------------- 1 | // This example is based on https://github.com/dmlc/xgboost/blob/master/demo/guide-python/custom_rmsle.py 2 | 3 | import Foundation 4 | import XGBoost 5 | 6 | let train = try DMatrix( 7 | name: "train", 8 | from: "Examples/Data/agaricus.txt.train", 9 | format: .libsvm 10 | ) 11 | let test = try DMatrix( 12 | name: "test", 13 | from: "Examples/Data/agaricus.txt.test", 14 | format: .libsvm 15 | ) 16 | 17 | let parameters = [ 18 | Parameter("seed", 0), 19 | Parameter("max_depth", 2), 20 | Parameter("objective", "binary:logistic"), 21 | Parameter("eval_metric", "logloss"), 22 | Parameter("eval_metric", "error"), 23 | ] 24 | 25 | let booster = try Booster( 26 | with: [train, test], 27 | parameters: parameters 28 | ) 29 | 30 | try booster.train( 31 | iterations: 2, 32 | trainingData: train, 33 | evaluationData: [train, test] 34 | ) { _, _, evaluation, _ in 35 | // Note: `evaluation` will be `nil`, if `evaluationData` are 
not provided. 36 | print("Training logloss:", evaluation![train.name]!["logloss"]!) 37 | print("Training error:", evaluation![train.name]!["error"]!) 38 | print("Testing logloss:", evaluation![test.name]!["logloss"]!) 39 | print("Testing error:", evaluation![test.name]!["error"]!) 40 | print("---") 41 | 42 | // Note: This callback can return `.stop` to break training loop. 43 | return .next 44 | } 45 | -------------------------------------------------------------------------------- /Examples/PredictFirstNTree/main.swift: -------------------------------------------------------------------------------- 1 | // This example is based on https://github.com/dmlc/xgboost/blob/master/demo/guide-python/predict_first_ntree.py 2 | 3 | import Foundation 4 | import XGBoost 5 | 6 | let train = try DMatrix( 7 | name: "train", 8 | from: "Examples/Data/agaricus.txt.train", 9 | format: .libsvm 10 | ) 11 | let test = try DMatrix( 12 | name: "test", 13 | from: "Examples/Data/agaricus.txt.test", 14 | format: .libsvm 15 | ) 16 | 17 | let parameters = [ 18 | Parameter("seed", 0), 19 | Parameter("eta", 1), 20 | Parameter("max_depth", 2), 21 | Parameter("objective", "binary:logistic"), 22 | ] 23 | 24 | let booster = try Booster( 25 | with: [train, test], 26 | parameters: parameters 27 | ) 28 | 29 | try booster.train( 30 | iterations: 3, 31 | trainingData: train 32 | ) 33 | 34 | let labels = try test.get(field: .label) 35 | 36 | // Predict using first 1 tree 37 | let yPred1 = try booster.predict(from: test, treeLimit: 1) 38 | let error1 = zip(yPred1, labels).map { 39 | ($0 > 0.5) != ($1 == 1) ? 1 : 0 40 | }.sum() / Float(labels.count) 41 | 42 | // By default, we predict using all the trees 43 | let yPred2 = try booster.predict(from: test) 44 | let error2 = zip(yPred2, labels).map { 45 | ($0 > 0.5) != ($1 == 1) ? 
1 : 0 46 | }.sum() / Float(labels.count) 47 | 48 | print( 49 | """ 50 | Error yPred1: \(error1) 51 | Error yPred2: \(error2) 52 | """ 53 | ) 54 | -------------------------------------------------------------------------------- /Examples/BoostFromPrediction/main.swift: -------------------------------------------------------------------------------- 1 | // This example is based on https://github.com/dmlc/xgboost/blob/master/demo/guide-python/boost_from_prediction.py 2 | 3 | import Foundation 4 | import PythonKit 5 | import XGBoost 6 | 7 | let train = try DMatrix( 8 | name: "train", 9 | from: "Examples/Data/agaricus.txt.train", 10 | format: .libsvm 11 | ) 12 | let test = try DMatrix( 13 | name: "test", 14 | from: "Examples/Data/agaricus.txt.test", 15 | format: .libsvm 16 | ) 17 | 18 | let parameters = [ 19 | Parameter("seed", 0), 20 | Parameter("max_depth", 2), 21 | Parameter("eta", 1), 22 | Parameter("objective", "binary:logistic"), 23 | ] 24 | 25 | let booster = try Booster( 26 | with: [train], 27 | parameters: parameters 28 | ) 29 | 30 | try booster.train( 31 | iterations: 1, 32 | trainingData: train 33 | ) 34 | 35 | // Note 36 | // We need the margin value instead of transformed prediction in set_base_margin. 37 | // Do predict with `outputMargin: true`, will always give you margin values before logistic transformation. 
38 | let predictedTrain = try booster.predict( 39 | from: train, outputMargin: true 40 | ) 41 | let predictedTest = try booster.predict( 42 | from: test, outputMargin: true 43 | ) 44 | 45 | try train.set(field: .baseMargin, values: predictedTrain) 46 | try test.set(field: .baseMargin, values: predictedTest) 47 | 48 | let booster2 = try Booster( 49 | with: [train], 50 | parameters: parameters 51 | ) 52 | 53 | try booster2.train( 54 | iterations: 1, 55 | trainingData: train 56 | ) 57 | 58 | let predictedTrain2 = try booster2.predict( 59 | from: train, outputMargin: true 60 | ) 61 | let predictedTest2 = try booster2.predict( 62 | from: test, outputMargin: true 63 | ) 64 | 65 | print(predictedTest2) 66 | -------------------------------------------------------------------------------- /Examples/GeneralizedLinearModel/main.swift: -------------------------------------------------------------------------------- 1 | // This example is based on https://github.com/dmlc/xgboost/blob/master/demo/guide-python/generalized_linear_model.py 2 | 3 | import Foundation 4 | import XGBoost 5 | 6 | // This script demonstrates how to fit a generalized linear model in xgboost 7 | // basically, we are using a linear model instead of trees for our boosters 8 | 9 | let train = try DMatrix( 10 | name: "train", 11 | from: "Examples/Data/agaricus.txt.train", 12 | format: .libsvm 13 | ) 14 | let test = try DMatrix( 15 | name: "test", 16 | from: "Examples/Data/agaricus.txt.test", 17 | format: .libsvm 18 | ) 19 | 20 | // change booster to gblinear, so that we are fitting a linear model 21 | // alpha is the L1 regularizer 22 | // lambda is the L2 regularizer 23 | // you can also set lambda_bias, which is the L2 regularizer on the bias term 24 | var parameters = [ 25 | Parameter("seed", 0), 26 | Parameter("lambda", 1), 27 | Parameter("alpha", 0.0001), 28 | Parameter("booster", "gblinear"), 29 | Parameter("objective", "binary:logistic"), 30 | ] 31 | 32 | // normally, you do not need to set eta (step_size) 33 | 
// XGBoost uses a parallel coordinate descent algorithm (shotgun), 34 | // there could be affection on convergence with parallelization on certain cases 35 | // setting eta to be smaller value, e.g 0.5 can make the optimization more stable 36 | // parameters.append(Parameter("eta", 1)) 37 | 38 | let booster = try Booster( 39 | with: [train, test], 40 | parameters: parameters 41 | ) 42 | 43 | try booster.train( 44 | iterations: 4, 45 | trainingData: train 46 | ) 47 | 48 | let preds = try booster.predict(from: test) 49 | let labels = try test.get(field: .label) 50 | 51 | let error = zip(preds, labels).map { 52 | ($0 > 0.5) == ($1 == 1) ? 0 : 1 53 | }.sum() / Float(preds.count) 54 | 55 | print("Test error: \(error).") 56 | -------------------------------------------------------------------------------- /Sources/XGBoost/Enums.swift: -------------------------------------------------------------------------------- 1 | import Foundation 2 | 3 | /// Typealias for underlying XGBoost version. 4 | public typealias Version = (major: Int, minor: Int, patch: Int) 5 | 6 | /// Tuple holding length of buffer along with it, so it can be easily read. 7 | public typealias RawModel = (length: UInt64, data: UnsafeMutablePointer?>) 8 | public typealias BufferModel = (length: UInt64, data: UnsafeMutablePointer?>) 9 | public typealias SerializedBuffer = (length: UInt64, data: UnsafeMutablePointer?>) 10 | 11 | /// Indicates if iteration should break (stop) or continue (next). 12 | public enum AfterIterationOutput { 13 | case stop, next 14 | } 15 | 16 | /// Currently supported model formats. 17 | public enum ModelFormat: String { 18 | case text, json, dot 19 | } 20 | 21 | /// Currently supported data formats. 22 | public enum DataFormat: String { 23 | case libsvm, csv, binary 24 | } 25 | 26 | /// Predefined names of float fields settable in DMatrix. 27 | /// You can set also another fields using method accepting strings. 
28 | public enum FloatField: String { 29 | case label 30 | case weight 31 | case baseMargin = "base_margin" 32 | case labelLowerBound = "label_lower_bound" 33 | case labelUpperBound = "label_upper_bound" 34 | } 35 | 36 | /// Predefined names of uint fields settable in DMatrix. 37 | /// You can set also another fields using method accepting strings. 38 | public enum UIntField: String { 39 | case group 40 | case groupPtr = "group_ptr" 41 | } 42 | 43 | /// Supported types for importance graph. 44 | public enum Importance: String { 45 | case weight 46 | case gain 47 | case cover 48 | case totalGain = "total_gain" 49 | case totalCover = "total_cover" 50 | } 51 | 52 | /// Supported Booster types. 53 | public enum BoosterType: String { 54 | case gbtree 55 | case gblinear 56 | case dart 57 | } 58 | -------------------------------------------------------------------------------- /Package.resolved: -------------------------------------------------------------------------------- 1 | { 2 | "object": { 3 | "pins": [ 4 | { 5 | "package": "LoggerKit", 6 | "repositoryURL": "https://github.com/kongzii/LoggerKit.git", 7 | "state": { 8 | "branch": null, 9 | "revision": "9b54626f49f97c38227009762cce4156dfa313a5", 10 | "version": "0.0.0" 11 | } 12 | }, 13 | { 14 | "package": "PythonKit", 15 | "repositoryURL": "https://github.com/kongzii/PythonKit", 16 | "state": { 17 | "branch": null, 18 | "revision": "977002d1f5cad96c84360d0ad4496ec614c58975", 19 | "version": "0.0.1" 20 | } 21 | }, 22 | { 23 | "package": "Rainbow", 24 | "repositoryURL": "https://github.com/onevcat/Rainbow", 25 | "state": { 26 | "branch": null, 27 | "revision": "9c52c1952e9b2305d4507cf473392ac2d7c9b155", 28 | "version": "3.1.5" 29 | } 30 | }, 31 | { 32 | "package": "swift-argument-parser", 33 | "repositoryURL": "https://github.com/apple/swift-argument-parser", 34 | "state": { 35 | "branch": null, 36 | "revision": "223d62adc52d51669ae2ee19bdb8b7d9fd6fcd9c", 37 | "version": "0.0.6" 38 | } 39 | }, 40 | { 41 | "package": 
"swift-log", 42 | "repositoryURL": "https://github.com/apple/swift-log.git", 43 | "state": { 44 | "branch": null, 45 | "revision": "74d7b91ceebc85daf387ebb206003f78813f71aa", 46 | "version": "1.2.0" 47 | } 48 | }, 49 | { 50 | "package": "SwiftPlot", 51 | "repositoryURL": "https://github.com/KarthikRIyer/swiftplot.git", 52 | "state": { 53 | "branch": null, 54 | "revision": "c5424eb8688217029d44e941d3a1d2ccc18c3a82", 55 | "version": "2.0.0" 56 | } 57 | } 58 | ] 59 | }, 60 | "version": 1 61 | } 62 | -------------------------------------------------------------------------------- /Examples/CrossValidation/main.swift: -------------------------------------------------------------------------------- 1 | // This example is based on https://github.com/dmlc/xgboost/blob/master/demo/guide-python/cross_validation.py 2 | 3 | import Foundation 4 | import PythonKit 5 | import XGBoost 6 | 7 | let pandas = Python.import("pandas") 8 | 9 | let data = try DMatrix( 10 | name: "data", 11 | from: "Examples/Data/agaricus.txt.train", 12 | format: .libsvm 13 | ) 14 | 15 | let parameters = [ 16 | Parameter("max_depth", "2"), 17 | Parameter("eta", "1"), 18 | Parameter("seed", "0"), 19 | Parameter("eval_metric", "error"), 20 | Parameter("objective", "binary:logistic"), 21 | ] 22 | 23 | let cv1 = try crossValidationTraining( 24 | data: data, 25 | splits: 5, 26 | iterations: 2, 27 | parameters: parameters, 28 | shuffle: false 29 | ) 30 | print(pandas.DataFrame.from_dict(cv1.results)) 31 | 32 | let cv2 = try crossValidationTraining( 33 | data: data, 34 | splits: 5, 35 | iterations: 10, 36 | parameters: parameters, 37 | earlyStopping: EarlyStopping( 38 | dataName: "data", 39 | metricName: "error", 40 | stoppingRounds: 3 41 | 42 | ), 43 | shuffle: false 44 | ) 45 | print(pandas.DataFrame.from_dict(cv2.results)) 46 | 47 | let objectiveFunction: ObjectiveFunction = { preds, data in 48 | let labels = try data.get(field: .label) 49 | let preds = preds.map { 1.0 / (1.0 + exp(-$0)) } 50 | let grad = (0 
..< labels.count).map { preds[$0] - labels[$0] } 51 | let hess = (0 ..< labels.count).map { preds[$0] * (1.0 - preds[$0]) } 52 | return (grad, hess) 53 | } 54 | 55 | let evaluationFunction: EvaluationFunction = { preds, data in 56 | let labels = try data.get(field: .label) 57 | let error = (0 ..< labels.count).map { 58 | (labels[$0] > 0) != (preds[$0] > 0) ? 1.0 : 0.0 59 | }.sum() / Float(labels.count) 60 | 61 | return ("c-error", "\(error)") 62 | } 63 | 64 | let cv3 = try crossValidationTraining( 65 | data: data, 66 | splits: 5, 67 | iterations: 2, 68 | parameters: parameters, 69 | objectiveFunction: objectiveFunction, 70 | evaluationFunction: evaluationFunction, 71 | shuffle: false 72 | ) 73 | print(pandas.DataFrame.from_dict(cv3.results)) 74 | -------------------------------------------------------------------------------- /Examples/CustomObjective/main.swift: -------------------------------------------------------------------------------- 1 | // This example is based on https://github.com/dmlc/xgboost/blob/master/demo/guide-python/custom_objective.py 2 | 3 | import Foundation 4 | import PythonKit 5 | import XGBoost 6 | 7 | let train = try DMatrix( 8 | name: "train", 9 | from: "Examples/Data/agaricus.txt.train", 10 | format: .libsvm 11 | ) 12 | let test = try DMatrix( 13 | name: "test", 14 | from: "Examples/Data/agaricus.txt.test", 15 | format: .libsvm 16 | ) 17 | 18 | let parameters = [ 19 | Parameter("seed", 0), 20 | Parameter("max_depth", 2), 21 | Parameter("eta", 1), 22 | ] 23 | 24 | // user define objective function, given prediction, return gradient and second 25 | // order gradient this is log likelihood loss 26 | let logRegObjective: ObjectiveFunction = { preds, data in 27 | let labels = try data.get(field: .label) 28 | let preds = preds.map { 1.0 / (1.0 + exp(-$0)) } 29 | let grad = (0 ..< labels.count).map { preds[$0] - labels[$0] } 30 | let hess = (0 ..< labels.count).map { preds[$0] * (1.0 - preds[$0]) } 31 | return (grad, hess) 32 | } 33 | 34 | // 
user defined evaluation function, return a pair metric_name, result 35 | // 36 | // NOTE: when you do customized loss function, the default prediction value is 37 | // margin. this may make builtin evaluation metric not function properly for 38 | // example, we are doing logistic loss, the prediction is score before logistic 39 | // transformation the builtin evaluation error assumes input is after logistic 40 | // transformation Take this in mind when you use the customization, and maybe 41 | // you need write customized evaluation function 42 | let evaluationFunction: EvaluationFunction = { preds, data in 43 | let labels = try data.get(field: .label) 44 | let error = (0 ..< labels.count).map { 45 | (labels[$0] > 0) != (preds[$0] > 0) ? 1.0 : 0.0 46 | }.sum() / Float(labels.count) 47 | 48 | return ("my-error", "\(error)") 49 | } 50 | 51 | let booster = try Booster( 52 | with: [train, test], 53 | parameters: parameters 54 | ) 55 | 56 | // training with customized objective, we can also do step by step training 57 | // simply look at Train.swift's implementation of train 58 | try booster.train( 59 | iterations: 2, 60 | trainingData: train, 61 | objectiveFunction: logRegObjective, 62 | evaluationFunction: evaluationFunction 63 | ) 64 | 65 | let testPreds = try booster.predict(from: test) 66 | -------------------------------------------------------------------------------- /Tests/XGBoostTests/ExampleTests.swift: -------------------------------------------------------------------------------- 1 | import XCTest 2 | 3 | @testable import XGBoost 4 | 5 | final class ExampleTests: XCTestCase { 6 | func testReadmeExample1() throws { 7 | // import XGBoost 8 | 9 | // Register your own callback function for log(info) messages 10 | try XGBoost.registerLogCallback { 11 | print("Swifty log:", String(cString: $0!)) 12 | } 13 | 14 | // Create some random features and labels 15 | let randomArray = (0 ..< 1000).map { _ in Float.random(in: 0 ..< 2) } 16 | let labels = (0 ..< 100).map { 
_ in Float([0, 1].randomElement()!) } 17 | 18 | // Initialize data, DMatrixHandle in the background 19 | let data = try DMatrix( 20 | name: "data", 21 | from: randomArray, 22 | shape: Shape(100, 10), 23 | label: labels, 24 | threads: 1 25 | ) 26 | 27 | // Slice array into train and test 28 | let train = try data.slice(indexes: 0 ..< 90, newName: "train") 29 | let test = try data.slice(indexes: 90 ..< 100, newName: "test") 30 | 31 | // Parameters for Booster, check https://xgboost.readthedocs.io/en/latest/parameter.html 32 | let parameters = [ 33 | Parameter("verbosity", "2"), 34 | Parameter("seed", "0"), 35 | ] 36 | 37 | // Create Booster model, `with` data will be cached 38 | let booster = try Booster( 39 | with: [train, test], 40 | parameters: parameters 41 | ) 42 | 43 | // Train booster, optionally provide callback functions called before and after each iteration 44 | try booster.train( 45 | iterations: 10, 46 | trainingData: train, 47 | evaluationData: [train, test] 48 | ) 49 | 50 | // Predict from test data 51 | let predictions = try booster.predict(from: test) 52 | 53 | // Save 54 | try booster.save(to: "model.xgboost") 55 | 56 | // Assert outputs 57 | XCTAssertEqual(try data.get(field: .label), labels) 58 | XCTAssertEqual(try data.rowCount(), 100) 59 | XCTAssertEqual(try data.columnCount(), 10) 60 | XCTAssertEqual(try train.rowCount(), 90) 61 | XCTAssertEqual(try train.columnCount(), 10) 62 | XCTAssertEqual(try test.rowCount(), 10) 63 | XCTAssertEqual(try test.columnCount(), 10) 64 | } 65 | 66 | static var allTests = [ 67 | ("testReadmeExample1", testReadmeExample1), 68 | ] 69 | } 70 | -------------------------------------------------------------------------------- /Tests/XGBoostTests/Utilities.swift: -------------------------------------------------------------------------------- 1 | import PythonKit 2 | import XCTest 3 | 4 | @testable import XGBoost 5 | 6 | func assertEqual( 7 | _ a: [String: Float], 8 | _ b: [String: Float], 9 | accuracy: Float 10 | ) { 
11 | for (key, value) in a { 12 | XCTAssertNotNil(b[key], "Missing key \(key)") 13 | if let other = b[key] { XCTAssertEqual(value, other, accuracy: accuracy) } 14 | } 15 | } 16 | 17 | func assertEqual( 18 | _ a: [String: [Float]], 19 | _ b: [String: [Float]], 20 | accuracy: Float 21 | ) { 22 | for (key, values) in a { 23 | XCTAssertNotNil(b[key], "Missing key \(key)") 24 | for (valueA, valueB) in zip(values, b[key] ?? []) { 25 | XCTAssertEqual(valueA, valueB, accuracy: accuracy) 26 | } 27 | } 28 | } 29 | 30 | func temporaryFile(name: String = "xgboost") -> String { 31 | FileManager.default.temporaryDirectory.appendingPathComponent( 32 | "\(name).\(Int.random(in: 0 ..< Int.max))", isDirectory: false 33 | ).path 34 | } 35 | 36 | func randomDMatrix( 37 | rows: Int = 10, 38 | columns: Int = 5, 39 | threads: Int = 1 40 | ) throws -> DMatrix { 41 | let randomArray = (0 ..< rows * columns).map { _ in Float.random(in: 0 ..< 2) } 42 | let label = (0 ..< rows).map { _ in Float([0, 1].randomElement()!) } 43 | let data = try DMatrix( 44 | name: "data", 45 | from: randomArray, 46 | shape: Shape(rows, columns), 47 | label: label, 48 | threads: threads 49 | ) 50 | return data 51 | } 52 | 53 | func randomTrainedBooster( 54 | iterations: Int = 1 55 | ) throws -> Booster { 56 | let data = try randomDMatrix() 57 | let booster = try Booster( 58 | with: [data], 59 | parameters: [Parameter("tree_method", "exact")] 60 | ) 61 | try booster.train(iterations: iterations, trainingData: data) 62 | 63 | return booster 64 | } 65 | 66 | func pythonXGBoost() throws -> PythonObject { 67 | try Python.attemptImport("xgboost") 68 | } 69 | 70 | func python( 71 | booster: Booster 72 | ) throws -> PythonObject { 73 | let temporaryModelFile = temporaryFile() 74 | try booster.save(to: temporaryModelFile) 75 | let pyXgboost = try Python.attemptImport("xgboost") 76 | 77 | return pyXgboost.Booster( 78 | model_file: temporaryModelFile, 79 | params: ["tree_method": "exact"] 80 | ) 81 | } 82 | 83 | func python( 84 | dmatrix: DMatrix 85 | ) throws -> PythonObject { 86 | let 
temporaryDataFile = temporaryFile() 87 | try dmatrix.save(to: temporaryDataFile) 88 | let pyXgboost = try Python.attemptImport("xgboost") 89 | 90 | return pyXgboost.DMatrix(data: temporaryDataFile) 91 | } 92 | -------------------------------------------------------------------------------- /Sources/XGBoost/Plot.swift: -------------------------------------------------------------------------------- 1 | import SVGRenderer 2 | import SwiftPlot 3 | 4 | extension Booster { 5 | /// Saves plot with importance based on fitted trees. 6 | /// 7 | /// - Parameter to: File where graph will be saved, .svg extension will be added if the renderer remains SVGRenderer. 8 | /// - Parameter featureMap: Path to the feature map, if provided, replaces default f0, f1, ... feature names. 9 | /// - Parameter importance: Type of importance to plot. 10 | /// - Parameter label: Label of graph. 11 | /// - Parameter title: Title of graph. 12 | /// - Parameter xAxisLabel: Label of X-axis. 13 | /// - Parameter yAxisLabel: Label of Y-axis. 14 | /// - Parameter maxNumberOfFeatures: Maximum number of top features displayed on plot; must be non-negative. If nil, or larger than the number of features, all features will be displayed. 15 | /// - Parameter graphOrientation: Orientation of the plotted graph. 16 | /// - Parameter enableGrid: Turn the axes grids on or off. 17 | /// - Parameter size: Size of the plotted graph. 18 | /// - Parameter renderer: Renderer to use. 19 | public func saveImportanceGraph( 20 | to fileName: String, 21 | featureMap: String = "", 22 | importance: Importance = .weight, 23 | label: String = "Feature importance", 24 | title: String = "Feature importance", 25 | xAxisLabel: String = "Score", 26 | yAxisLabel: String = "Features", 27 | maxNumberOfFeatures: Int? 
= nil, 28 | graphOrientation: BarGraph.GraphOrientation = .horizontal, 29 | enableGrid: Bool = true, 30 | size: Size = Size(width: 1000, height: 660), 31 | renderer: Renderer = SVGRenderer() 32 | ) throws { 33 | let (features, gains) = try score(featureMap: featureMap, importance: importance) 34 | 35 | if features.isEmpty { 36 | throw ValueError.runtimeError( 37 | "Score from booster is empty. This maybe caused by having all trees as decision dumps." 38 | ) 39 | } 40 | 41 | var importance: [(name: String, value: Float)] = { 42 | switch importance { 43 | case .weight: 44 | return features.map { ($0, Float($1)) } 45 | case .gain, .cover, .totalGain, .totalCover: 46 | return gains!.map { ($0, $1) } 47 | } 48 | }() 49 | 50 | importance = importance.sorted(by: { $0.value < $1.value }) 51 | // Keep only the highest-valued entries; `suffix` clamps to the available count instead of trapping. 52 | if let maxNumberOfFeatures = maxNumberOfFeatures { 53 | importance = Array(importance.suffix(maxNumberOfFeatures)) 54 | } 55 | 56 | let x = importance.map(\.name) 57 | let y = importance.map(\.value) 58 | 59 | var graph = BarGraph(enableGrid: enableGrid) 60 | graph.addSeries(x, y, label: label, graphOrientation: graphOrientation) 61 | graph.plotTitle.title = title 62 | graph.plotLabel.xLabel = xAxisLabel 63 | graph.plotLabel.yLabel = yAxisLabel 64 | try graph.drawGraphAndOutput(size: size, fileName: fileName, renderer: renderer) 65 | } 66 | } 67 | -------------------------------------------------------------------------------- /Tests/XGBoostTests/ArrayTests.swift: -------------------------------------------------------------------------------- 1 | import XCTest 2 | 3 | @testable import XGBoost 4 | 5 | final class ArrayTests: XCTestCase { 6 | func testThrowInvalidFeatureMap() throws { 7 | let path = temporaryFile() 8 | let featureMap = """ 9 | 0 feature1 q 10 | 1 feature2 u 11 | """ 12 | 13 | try featureMap.write( 14 | toFile: path, 15 | atomically: true, 16 | encoding: .utf8 17 | ) 18 | 19 | XCTAssertThrowsError(try [Feature](fromFeatureMap: 
path)) 20 | } 21 | 22 | func testDiff() { 23 | let values = [1, 2, 3, 4, 10] 24 | let expectedDiff = [1, 1, 1, 6] 25 | XCTAssertEqual(values.diff(), expectedDiff) 26 | } 27 | 28 | func testArrayWithShapeCount() { 29 | let array = ArrayWithShape([1, 2, 3, 4, 5, 6], shape: Shape(3, 2)) 30 | XCTAssertEqual(array.count, 6) 31 | } 32 | 33 | func testArrayWithShapeShape() { 34 | let array = ArrayWithShape([1, 2, 3, 4, 5, 6], shape: Shape(3, 2)) 35 | XCTAssertEqual(array.shape, Shape(3, 2)) 36 | XCTAssertEqual(try array.dataShape(), Shape(3, 2)) 37 | } 38 | 39 | func testArrayWithShapeSubscript() { 40 | var array = ArrayWithShape([1, 2, 3, 4, 5, 6], shape: Shape(3, 2)) 41 | let arrayExpected = ArrayWithShape([2, 2, 3, 4, 5, 6], shape: Shape(3, 2)) 42 | array[0] = 2 43 | XCTAssertEqual(array[0], 2) 44 | XCTAssertEqual(array, arrayExpected) 45 | } 46 | 47 | func testArrayWithShapeDataComfortances() { 48 | let floatArray = ArrayWithShape([1, 2, 3], shape: Shape(3, 1)) 49 | let intArray = ArrayWithShape([1, 2, 3], shape: Shape(3, 1)) 50 | let uintArray = ArrayWithShape([1, 2, 3], shape: Shape(3, 1)) 51 | 52 | XCTAssertEqual(try floatArray.data(), [1, 2, 3]) 53 | XCTAssertEqual(try intArray.data(), [1, 2, 3]) 54 | XCTAssertEqual(try uintArray.data(), [1, 2, 3]) 55 | 56 | XCTAssertTrue(floatArray == [1, 2, 3]) 57 | XCTAssertTrue(intArray == [1, 2, 3]) 58 | XCTAssertTrue(uintArray == [1, 2, 3]) 59 | 60 | XCTAssertTrue([1, 2, 3] == floatArray) 61 | XCTAssertTrue([1, 2, 3] == intArray) 62 | XCTAssertTrue([1, 2, 3] == uintArray) 63 | } 64 | 65 | func testArrayWithShapeIterator() { 66 | let intArray = ArrayWithShape([1, 2, 3], shape: Shape(3, 1)) 67 | 68 | for (index, value) in intArray.enumerated() { 69 | XCTAssertEqual(value, intArray[index]) 70 | } 71 | } 72 | 73 | static var allTests = [ 74 | ("testThrowInvalidFeatureMap", testThrowInvalidFeatureMap), 75 | ("testDiff", testDiff), 76 | ("testArrayWithShapeCount", testArrayWithShapeCount), 77 | ("testArrayWithShapeShape", 
testArrayWithShapeShape), 78 | ("testArrayWithShapeSubscript", testArrayWithShapeSubscript), 79 | ("testArrayWithShapeDataComfortances", testArrayWithShapeDataComfortances), 80 | ("testArrayWithShapeIterator", testArrayWithShapeIterator), 81 | ] 82 | } 83 | -------------------------------------------------------------------------------- /Tests/XGBoostTests/XGBoostDocsTests.swift: -------------------------------------------------------------------------------- 1 | import CXGBoost 2 | import PythonKit 3 | import XCTest 4 | 5 | #if canImport(TensorFlow) 6 | import TensorFlow 7 | #endif 8 | 9 | @testable import XGBoost 10 | 11 | final class XGBoostDocsTests: XCTestCase { 12 | func testXGBoostIntroDoc() throws { 13 | let numpy = Python.import("numpy") 14 | let pandas = Python.import("pandas") 15 | 16 | let dataFrame = pandas.read_csv("Examples/Data/veterans_lung_cancer.csv") 17 | 18 | #if canImport(TensorFlow) 19 | let tensor = Tensor(shape: TensorShape([2, 3]), scalars: [1, 2, 3, 4, 5, 6]) 20 | let tensorData = try DMatrix(name: "tensorData", from: tensor) 21 | #endif 22 | 23 | let svmData = try DMatrix(name: "train", from: "Examples/Data/data.svm.txt", format: .libsvm) 24 | 25 | let csvData = try DMatrix(name: "train", from: "Examples/Data/data.csv", format: .csv, labelColumn: 0) 26 | 27 | let numpyData = try DMatrix(name: "train", from: numpy.random.rand(5, 10), label: numpy.random.randint(2, size: 5)) 28 | 29 | let pandasDataFrame = pandas.DataFrame(numpy.arange(12).reshape([4, 3]), columns: ["a", "b", "c"]) 30 | let pandasLabel = numpy.random.randint(2, size: 4) 31 | let pandasData = try DMatrix(name: "data", from: pandasDataFrame.values, label: pandasLabel) 32 | 33 | try pandasData.save(to: "train.buffer") 34 | 35 | let dataWithMissingValues = try DMatrix(name: "data", from: pandasDataFrame.values, missingValue: 999.0) 36 | 37 | try dataWithMissingValues.set(field: .weight, values: [Float](repeating: 1, count: try dataWithMissingValues.rowCount())) 38 | 39 | let 
labelsFromData = try pandasData.get(field: .label) 40 | 41 | let firstBooster = try Booster() 42 | try firstBooster.set(parameter: "tree_method", value: "hist") 43 | 44 | let parameters = [Parameter(name: "tree_method", value: "hist")] 45 | let secondBooster = try Booster(parameters: parameters) 46 | 47 | let trainingData = try DMatrix(name: "train", from: "Examples/Data/data.csv", format: .csv, labelColumn: 0) 48 | let boosterWithCachedData = try Booster(with: [trainingData]) 49 | try boosterWithCachedData.train(iterations: 5, trainingData: trainingData) 50 | 51 | try boosterWithCachedData.save(to: "0001.xgboost") 52 | 53 | let textModel = try boosterWithCachedData.dumped(format: .text) 54 | 55 | let loadedBooster = try Booster(from: "0001.xgboost") 56 | 57 | let testDataNumpy = try DMatrix(name: "test", from: numpy.random.rand(7, 12)) 58 | let predictionNumpy = try loadedBooster.predict(from: testDataNumpy) 59 | 60 | let testData = try DMatrix(name: "test", from: [69.0, 60.0, 7.0, 0, 0, 0, 1, 1, 0, 1, 0, 0], shape: Shape(1, 12)) 61 | let prediction = try loadedBooster.predict(from: testData) 62 | 63 | try boosterWithCachedData.saveImportanceGraph(to: "importance") // .svg extension will be added 64 | 65 | try safe { XGBoosterSaveModel(boosterWithCachedData.booster, "0002.xgboost") } 66 | } 67 | 68 | static var allTests = [ 69 | ("testXGBoostIntroDoc", testXGBoostIntroDoc), 70 | ] 71 | } 72 | -------------------------------------------------------------------------------- /Tests/XGBoostTests/CrossValidationTests.swift: -------------------------------------------------------------------------------- 1 | import PythonKit 2 | import XCTest 3 | 4 | @testable import XGBoost 5 | 6 | final class CrossValidationTests: XCTestCase { 7 | func testBasicCrossValidation() throws { 8 | let randomArray = (0 ..< 10000).map { _ in Float.random(in: 0 ..< 2) } 9 | let label = (0 ..< 1000).map { _ in Float([0, 1].randomElement()!) 
} 10 | let data = try DMatrix( 11 | name: "data", 12 | from: randomArray, 13 | shape: Shape(1000, 10), 14 | label: label, 15 | threads: 1 16 | ) 17 | 18 | let pyData = try python(dmatrix: data) 19 | 20 | let (results, _) = try crossValidationTraining( 21 | data: data, 22 | splits: 5, 23 | iterations: 10, 24 | parameters: [ 25 | Parameter("seed", 0), 26 | ], 27 | shuffle: false 28 | ) 29 | 30 | let pyJsonResults = try pythonXGBoost().cv( 31 | params: ["seed": 0], 32 | dtrain: pyData, 33 | nfold: 5, 34 | num_boost_round: 10, 35 | as_pandas: false, 36 | seed: 0, 37 | shuffle: false 38 | ) 39 | 40 | var parsedPyJsonResults = [String: [Float]]() 41 | for name in pyJsonResults { 42 | parsedPyJsonResults["data-\(name)"] = [Float](pyJsonResults[name])! 43 | } 44 | 45 | assertEqual(results, parsedPyJsonResults, accuracy: 1e-6) 46 | } 47 | 48 | func testEarlyStoppingCrossValidation() throws { 49 | let randomArray = (0 ..< 10000).map { _ in Float.random(in: 0 ..< 2) } 50 | let label = (0 ..< 1000).map { _ in Float([0, 1].randomElement()!) } 51 | let data = try DMatrix( 52 | name: "data", 53 | from: randomArray, 54 | shape: Shape(1000, 10), 55 | label: label, 56 | threads: 1 57 | ) 58 | 59 | let pyData = try python(dmatrix: data) 60 | 61 | let (results, _) = try crossValidationTraining( 62 | data: data, 63 | splits: 5, 64 | iterations: 10, 65 | parameters: [ 66 | Parameter("seed", 0), 67 | ], 68 | earlyStopping: EarlyStopping( 69 | dataName: "data", 70 | metricName: "rmse", 71 | stoppingRounds: 3 72 | ), 73 | shuffle: false 74 | ) 75 | 76 | let pyJsonResults = try pythonXGBoost().cv( 77 | params: ["seed": 0], 78 | dtrain: pyData, 79 | nfold: 5, 80 | num_boost_round: 10, 81 | as_pandas: false, 82 | seed: 0, 83 | callbacks: [try pythonXGBoost().callback.early_stop(3)], 84 | shuffle: false 85 | ) 86 | 87 | var parsedPyJsonResults = [String: [Float]]() 88 | for name in pyJsonResults { 89 | parsedPyJsonResults["data-\(name)"] = [Float](pyJsonResults[name])! 
// This example is based on https://github.com/dmlc/xgboost/blob/2c1a439869506532f48d387e7d39beeab358c76b/demo/aft_survival/aft_survival_demo.py.

import Foundation
import PythonKit
import XGBoost

let skLearnModelSelection = try Python.attemptImport("sklearn.model_selection")
let pandas = try Python.attemptImport("pandas")

// Load data from CSV using Pandas
let dataFrame = pandas.read_csv("Examples/Data/veterans_lung_cancer.csv")

// Split features and labels
let yLowerBound = dataFrame["Survival_label_lower_bound"]
let yUpperBound = dataFrame["Survival_label_upper_bound"]

// TODO: Replace when following bug is fixed (Pandas drop function is in collision with Swift function)
// Incorrect argument label in call (have 'columns:', expected 'while:')
// Swift.Collection:5:40: note: 'drop(while:)' declared here
// @inlinable public __consuming func drop(while predicate: (Self.Element) throws -> Bool) rethrows -> Self.SubSequence

// ! let X = dataFrame.drop(columns: ["Survival_label_lower_bound", "Survival_label_upper_bound"])
let X = Python.getattr(dataFrame, "drop")(columns: ["Survival_label_lower_bound", "Survival_label_upper_bound"])

// Split data into training and validation sets
let rs = skLearnModelSelection.ShuffleSplit(n_splits: 2, test_size: 0.7, random_state: 0)

// ! let splitted = Python.next(rs.split(X))
let splitted = Python.next(Python.getattr(rs, "split")(X))

let (trainIndex, validIndex) = (splitted[0], splitted[1])

let trainingData = try DMatrix(
    name: "training",
    from: X.values[trainIndex]
)

try trainingData.set(field: .labelLowerBound, values: yLowerBound.values[trainIndex])
try trainingData.set(field: .labelUpperBound, values: yUpperBound.values[trainIndex])

let validationData = try DMatrix(
    name: "validation",
    from: X.values[validIndex]
)

try validationData.set(field: .labelLowerBound, values: yLowerBound.values[validIndex])
try validationData.set(field: .labelUpperBound, values: yUpperBound.values[validIndex])

// Train gradient boosted trees using AFT loss and metric
let parameters = [
    Parameter("verbosity", "0"),
    Parameter("objective", "survival:aft"),
    Parameter("eval_metric", "aft-nloglik"),
    Parameter("tree_method", "hist"),
    Parameter("learning_rate", "0.05"),
    Parameter("aft_loss_distribution", "normal"),
    Parameter("aft_loss_distribution_scale", "1.20"),
    Parameter("max_depth", "6"),
    Parameter("lambda", "0.01"),
    Parameter("alpha", "0.02"),
]

let booster = try Booster(
    with: [trainingData, validationData],
    parameters: parameters
)

try booster.train(
    iterations: 10000,
    trainingData: trainingData,
    evaluationData: [trainingData, validationData]
) { _, iteration, evaluation, _ in
    // The evaluation dictionary is optional; report what is available instead of
    // crashing the example when a data set has no statistics for this round.
    if let trainingStats = evaluation?["training"] {
        print("Training stats at \(iteration): \(trainingStats)")
    }
    if let validationStats = evaluation?["validation"] {
        print("Validation stats at \(iteration): \(validationStats)")
    }
    return .next
}

let predicted = try booster.predict(
    from: validationData
)

// Side-by-side view of the label bounds and the model output for the validation rows.
let compare = pandas.DataFrame([
    "Label lower bound": yLowerBound[validIndex],
    "Label upper bound": yUpperBound[validIndex],
    "Predicted": predicted.makeNumpyArray(),
])

print(compare)

try booster.save(to: "aft_model.xgboost")
predicted.makeNumpyArray(), 86 | ]) 87 | 88 | print(compare) 89 | 90 | try booster.save(to: "aft_model.xgboost") 91 | -------------------------------------------------------------------------------- /Tests/XGBoostTests/TrainTests.swift: -------------------------------------------------------------------------------- 1 | import XCTest 2 | 3 | @testable import XGBoost 4 | 5 | final class TrainTests: XCTestCase { 6 | func testEarlyStoppingMinimize() throws { 7 | let randomArray = (0 ..< 10).map { _ in Float.random(in: 0 ..< 2) } 8 | let label = (0 ..< 10).map { _ in Float([0, 1].randomElement()!) } 9 | let data = try DMatrix( 10 | name: "data", 11 | from: randomArray, 12 | shape: Shape(10, 1), 13 | label: label, 14 | threads: 1 15 | ) 16 | let booster = try Booster( 17 | with: [data], 18 | parameters: [Parameter(name: "seed", value: "0")] 19 | ) 20 | 21 | var scores = [String]() 22 | var iterations = [Int]() 23 | 24 | try booster.train( 25 | iterations: 100, 26 | trainingData: data, 27 | evaluationData: [data], 28 | callbacks: [EarlyStopping( 29 | dataName: "data", 30 | metricName: "rmse", 31 | stoppingRounds: 5, 32 | verbose: true 33 | )] 34 | ) { _, iteration, evaluation, _ in 35 | iterations.append(iteration) 36 | scores.append(evaluation!["data"]!["rmse"]!) 37 | return .next 38 | } 39 | 40 | XCTAssertTrue(scores.count >= 5) 41 | XCTAssertTrue(scores.first! >= scores.last!) 42 | XCTAssertTrue(scores[scores.count - 5 ..< scores.count].allSatisfy { $0 == scores.last! }) 43 | XCTAssertEqual(Double(try booster.attribute(name: "best_score")!)!, Double(scores.last!)!) 44 | XCTAssertEqual(iterations[iterations.count - 5 - 1], Int(try booster.attribute(name: "best_iteration")!)!) 45 | } 46 | 47 | func testEarlyStoppingMaximize() throws { 48 | let randomArray = (0 ..< 10).map { _ in Float.random(in: 0 ..< 2) } 49 | let label = (0 ..< 10).map { _ in Float([0, 1].randomElement()!) 
} 50 | let data = try DMatrix( 51 | name: "data", 52 | from: randomArray, 53 | shape: Shape(10, 1), 54 | label: label, 55 | threads: 1 56 | ) 57 | let booster = try Booster( 58 | with: [data], 59 | parameters: [ 60 | Parameter(name: "seed", value: "0"), 61 | Parameter(name: "eval_metric", value: "rmse"), 62 | Parameter(name: "eval_metric", value: "auc"), 63 | ] 64 | ) 65 | 66 | var scores = [String]() 67 | var iterations = [Int]() 68 | 69 | try booster.train( 70 | iterations: 100, 71 | trainingData: data, 72 | evaluationData: [data], 73 | callbacks: [EarlyStopping( 74 | dataName: "data", 75 | metricName: "auc", 76 | stoppingRounds: 10, 77 | verbose: true 78 | )] 79 | ) { _, iteration, evaluation, _ in 80 | iterations.append(iteration) 81 | scores.append(evaluation!["data"]!["auc"]!) 82 | return .next 83 | } 84 | 85 | XCTAssertTrue(scores.count >= 10) 86 | XCTAssertTrue(scores.first! <= scores.last!) 87 | XCTAssertTrue(scores[scores.count - 10 ..< scores.count].allSatisfy { $0 == scores.last! }) 88 | XCTAssertEqual(Double(try booster.attribute(name: "best_score")!)!, Double(scores.last!)!) 89 | XCTAssertEqual(iterations[iterations.count - 10 - 1], Int(try booster.attribute(name: "best_iteration")!)!) 
// swift-tools-version:5.1

import PackageDescription

/// Example executables. Each one lives in `Examples/<name>` and links the same two modules.
let exampleNames = [
    "AftSurvival",
    "CrossValidation",
    "BoostFromPrediction",
    "CustomObjective",
    "EvalsResult",
    "ExternalMemory",
    "GeneralizedLinearModel",
    "PredictFirstNTree",
    "PredictLeafIndices",
]

let exampleTargets: [Target] = exampleNames.map { name in
    .target(
        name: name,
        dependencies: [
            "XGBoost",
            "PythonKit",
        ],
        path: "Examples/\(name)"
    )
}

/// Library, system library and test targets.
let coreTargets: [Target] = [
    .target(
        name: "XGBoost",
        dependencies: [
            "CXGBoost",
            "Logging",
            "PythonKit",
            "SwiftPlot",
            "SVGRenderer",
        ]
    ),
    .systemLibrary(
        name: "CXGBoost",
        pkgConfig: "xgboost",
        providers: [
            .brew(["xgboost"]),
        ]
    ),
    .testTarget(
        name: "XGBoostTests",
        dependencies: [
            "XGBoost",
            "PythonKit",
        ]
    ),
]

let package = Package(
    name: "SwiftXGBoost",
    platforms: [
        .macOS(.v10_13),
    ],
    products: [
        .library(
            name: "XGBoost",
            targets: [
                "XGBoost",
            ]
        ),
        .library(
            name: "CXGBoost",
            targets: [
                "CXGBoost",
            ]
        ),
    ],
    dependencies: [
        // TODO: Switch to official PythonKit when it's versioned.
        .package(url: "https://github.com/kongzii/PythonKit", .exact("0.0.1")),
        // .package(url: "https://github.com/pvieito/PythonKit.git", .branch("master")),
        .package(url: "https://github.com/KarthikRIyer/swiftplot.git", from: "2.0.0"),
        .package(url: "https://github.com/apple/swift-log.git", from: "1.0.0"),
    ],
    // Same targets in the same order as before: core targets first, then the examples.
    targets: coreTargets + exampleTargets
)
public extension ArrayWithShape where Element: NumpyScalarCompatible {
    /// - Precondition: The `numpy` Python package must be installed.
    /// - Returns: Properly shaped numpy array holding a copy of the elements.
    func makeNumpyArray() -> PythonObject {
        // Delegate to the `Array` helper from this file so the numpy dtype follows
        // `Element` instead of always reinterpreting the bytes as `c_float`.
        array.makeNumpyArray(shape: shape)
    }
}

extension Shape {
    /// Init shape from PythonObject.
    ///
    /// - Parameter shape: Python object holding integers that can be converted to [Int].
    /// - Precondition: The conversion to `[Int]` must succeed, otherwise this traps.
    public init(_ shape: PythonObject) {
        self = [Int](shape)!
    }

    /// Init shape from PythonObjects.
    ///
    /// - Parameter elements: Python objects, each convertible to `Int`.
    /// - Precondition: Every element must convert to `Int`, otherwise this traps.
    public init(_ elements: PythonObject...) {
        self = elements.map { Int($0)! }
    }
}

/// PythonObject conformances for protocols that allow using Python objects seamlessly with Booster and DMatrix.
extension PythonObject: FloatData, Int32Data, UInt32Data, ShapeData {
    /// Conformance for FloatData.
    ///
    /// - Returns: Flattened elements of the numpy array converted to `float32`.
    /// - Throws: `ValueError.runtimeError` when `self` is not a `numpy.ndarray` or its buffer can not be read.
    public func data() throws -> [Float] {
        return try numpyScalars(dtype: numpy.float32, protocolName: "FloatData")
    }

    /// Conformance for Int32Data.
    ///
    /// - Returns: Flattened elements of the numpy array converted to `int32`.
    /// - Throws: `ValueError.runtimeError` when `self` is not a `numpy.ndarray` or its buffer can not be read.
    public func data() throws -> [Int32] {
        return try numpyScalars(dtype: numpy.int32, protocolName: "Int32Data")
    }

    /// Conformance for UInt32Data.
    ///
    /// - Returns: Flattened elements of the numpy array converted to `uint32`.
    /// - Throws: `ValueError.runtimeError` when `self` is not a `numpy.ndarray` or its buffer can not be read.
    public func data() throws -> [UInt32] {
        return try numpyScalars(dtype: numpy.uint32, protocolName: "UInt32Data")
    }

    /// Conformance for ShapeData.
    ///
    /// - Returns: Shape of the numpy array.
    /// - Throws: `ValueError.runtimeError` when `self` is not a `numpy.ndarray`.
    public func dataShape() throws -> Shape {
        guard Bool(Python.isinstance(self, numpy.ndarray)) == true else {
            throw ValueError.runtimeError("PythonObject type \(Python.type(self)) is not supported ShapeData.")
        }

        return Shape(self.shape)
    }

    /// Shared implementation of the three `data()` conformances.
    ///
    /// - Parameter dtype: numpy dtype whose memory layout matches `Scalar`.
    /// - Parameter protocolName: Protocol name used in the "not supported" error message.
    /// - Returns: Copy of the flattened array contents.
    private func numpyScalars<Scalar>(dtype: PythonObject, protocolName: String) throws -> [Scalar] {
        guard Bool(Python.isinstance(self, numpy.ndarray)) == true else {
            // `self.dtype` traps when the attribute does not exist (e.g. a plain list),
            // so go through the checking lookup to be able to throw instead.
            var dtypeDescription = "unknown dtype"
            if let objectDtype = checking.dtype {
                dtypeDescription = "\(objectDtype)"
            }

            throw ValueError.runtimeError(
                "PythonObject type \(Python.type(self)) [\(dtypeDescription)] is not supported \(protocolName)."
            )
        }

        guard let size = Int(self.size) else {
            throw ValueError.runtimeError("Can not get element count from numpy object.")
        }

        // `ascontiguousarray` converts the dtype and copies only when required. The former
        // `numpy.array(copy: false)` raises on NumPy >= 2 whenever a conversion needs a copy.
        let contiguousData = numpy.ascontiguousarray(self.reshape(size), dtype: dtype)

        // Keep the (possibly temporary) numpy array alive until its buffer has been copied.
        return try withExtendedLifetime(contiguousData) {
            guard let ptrVal = UInt(contiguousData.__array_interface__["data"].tuple2.0) else {
                throw ValueError.runtimeError("Can not get pointer value from numpy object.")
            }

            guard let pointer = UnsafePointer<Scalar>(bitPattern: ptrVal) else {
                throw ValueError.runtimeError("numpy.ndarray data pointer was nil.")
            }

            return Array(UnsafeBufferPointer(start: pointer, count: size))
        }
    }
}
{ 113 | return Shape(self.shape) 114 | } else { 115 | throw ValueError.runtimeError("PythonObject type \(Python.type(self)) is not supported ShapeData.") 116 | } 117 | } 118 | } 119 | -------------------------------------------------------------------------------- /Examples/Data/veterans_lung_cancer.csv: -------------------------------------------------------------------------------- 1 | Survival_label_lower_bound,Survival_label_upper_bound,Age_in_years,Karnofsky_score,Months_from_Diagnosis,Celltype=adeno,Celltype=large,Celltype=smallcell,Celltype=squamous,Prior_therapy=no,Prior_therapy=yes,Treatment=standard,Treatment=test 2 | 72.0,72.0,69.0,60.0,7.0,0,0,0,1,1,0,1,0 3 | 411.0,411.0,64.0,70.0,5.0,0,0,0,1,0,1,1,0 4 | 228.0,228.0,38.0,60.0,3.0,0,0,0,1,1,0,1,0 5 | 126.0,126.0,63.0,60.0,9.0,0,0,0,1,0,1,1,0 6 | 118.0,118.0,65.0,70.0,11.0,0,0,0,1,0,1,1,0 7 | 10.0,10.0,49.0,20.0,5.0,0,0,0,1,1,0,1,0 8 | 82.0,82.0,69.0,40.0,10.0,0,0,0,1,0,1,1,0 9 | 110.0,110.0,68.0,80.0,29.0,0,0,0,1,1,0,1,0 10 | 314.0,314.0,43.0,50.0,18.0,0,0,0,1,1,0,1,0 11 | 100.0,inf,70.0,70.0,6.0,0,0,0,1,1,0,1,0 12 | 42.0,42.0,81.0,60.0,4.0,0,0,0,1,1,0,1,0 13 | 8.0,8.0,63.0,40.0,58.0,0,0,0,1,0,1,1,0 14 | 144.0,144.0,63.0,30.0,4.0,0,0,0,1,1,0,1,0 15 | 25.0,inf,52.0,80.0,9.0,0,0,0,1,0,1,1,0 16 | 11.0,11.0,48.0,70.0,11.0,0,0,0,1,0,1,1,0 17 | 30.0,30.0,61.0,60.0,3.0,0,0,1,0,1,0,1,0 18 | 384.0,384.0,42.0,60.0,9.0,0,0,1,0,1,0,1,0 19 | 4.0,4.0,35.0,40.0,2.0,0,0,1,0,1,0,1,0 20 | 54.0,54.0,63.0,80.0,4.0,0,0,1,0,0,1,1,0 21 | 13.0,13.0,56.0,60.0,4.0,0,0,1,0,1,0,1,0 22 | 123.0,inf,55.0,40.0,3.0,0,0,1,0,1,0,1,0 23 | 97.0,inf,67.0,60.0,5.0,0,0,1,0,1,0,1,0 24 | 153.0,153.0,63.0,60.0,14.0,0,0,1,0,0,1,1,0 25 | 59.0,59.0,65.0,30.0,2.0,0,0,1,0,1,0,1,0 26 | 117.0,117.0,46.0,80.0,3.0,0,0,1,0,1,0,1,0 27 | 16.0,16.0,53.0,30.0,4.0,0,0,1,0,0,1,1,0 28 | 151.0,151.0,69.0,50.0,12.0,0,0,1,0,1,0,1,0 29 | 22.0,22.0,68.0,60.0,4.0,0,0,1,0,1,0,1,0 30 | 56.0,56.0,43.0,80.0,12.0,0,0,1,0,0,1,1,0 31 | 21.0,21.0,55.0,40.0,2.0,0,0,1,0,0,1,1,0 
32 | 18.0,18.0,42.0,20.0,15.0,0,0,1,0,1,0,1,0 33 | 139.0,139.0,64.0,80.0,2.0,0,0,1,0,1,0,1,0 34 | 20.0,20.0,65.0,30.0,5.0,0,0,1,0,1,0,1,0 35 | 31.0,31.0,65.0,75.0,3.0,0,0,1,0,1,0,1,0 36 | 52.0,52.0,55.0,70.0,2.0,0,0,1,0,1,0,1,0 37 | 287.0,287.0,66.0,60.0,25.0,0,0,1,0,0,1,1,0 38 | 18.0,18.0,60.0,30.0,4.0,0,0,1,0,1,0,1,0 39 | 51.0,51.0,67.0,60.0,1.0,0,0,1,0,1,0,1,0 40 | 122.0,122.0,53.0,80.0,28.0,0,0,1,0,1,0,1,0 41 | 27.0,27.0,62.0,60.0,8.0,0,0,1,0,1,0,1,0 42 | 54.0,54.0,67.0,70.0,1.0,0,0,1,0,1,0,1,0 43 | 7.0,7.0,72.0,50.0,7.0,0,0,1,0,1,0,1,0 44 | 63.0,63.0,48.0,50.0,11.0,0,0,1,0,1,0,1,0 45 | 392.0,392.0,68.0,40.0,4.0,0,0,1,0,1,0,1,0 46 | 10.0,10.0,67.0,40.0,23.0,0,0,1,0,0,1,1,0 47 | 8.0,8.0,61.0,20.0,19.0,1,0,0,0,0,1,1,0 48 | 92.0,92.0,60.0,70.0,10.0,1,0,0,0,1,0,1,0 49 | 35.0,35.0,62.0,40.0,6.0,1,0,0,0,1,0,1,0 50 | 117.0,117.0,38.0,80.0,2.0,1,0,0,0,1,0,1,0 51 | 132.0,132.0,50.0,80.0,5.0,1,0,0,0,1,0,1,0 52 | 12.0,12.0,63.0,50.0,4.0,1,0,0,0,0,1,1,0 53 | 162.0,162.0,64.0,80.0,5.0,1,0,0,0,1,0,1,0 54 | 3.0,3.0,43.0,30.0,3.0,1,0,0,0,1,0,1,0 55 | 95.0,95.0,34.0,80.0,4.0,1,0,0,0,1,0,1,0 56 | 177.0,177.0,66.0,50.0,16.0,0,1,0,0,0,1,1,0 57 | 162.0,162.0,62.0,80.0,5.0,0,1,0,0,1,0,1,0 58 | 216.0,216.0,52.0,50.0,15.0,0,1,0,0,1,0,1,0 59 | 553.0,553.0,47.0,70.0,2.0,0,1,0,0,1,0,1,0 60 | 278.0,278.0,63.0,60.0,12.0,0,1,0,0,1,0,1,0 61 | 12.0,12.0,68.0,40.0,12.0,0,1,0,0,0,1,1,0 62 | 260.0,260.0,45.0,80.0,5.0,0,1,0,0,1,0,1,0 63 | 200.0,200.0,41.0,80.0,12.0,0,1,0,0,0,1,1,0 64 | 156.0,156.0,66.0,70.0,2.0,0,1,0,0,1,0,1,0 65 | 182.0,inf,62.0,90.0,2.0,0,1,0,0,1,0,1,0 66 | 143.0,143.0,60.0,90.0,8.0,0,1,0,0,1,0,1,0 67 | 105.0,105.0,66.0,80.0,11.0,0,1,0,0,1,0,1,0 68 | 103.0,103.0,38.0,80.0,5.0,0,1,0,0,1,0,1,0 69 | 250.0,250.0,53.0,70.0,8.0,0,1,0,0,0,1,1,0 70 | 100.0,100.0,37.0,60.0,13.0,0,1,0,0,0,1,1,0 71 | 999.0,999.0,54.0,90.0,12.0,0,0,0,1,0,1,0,1 72 | 112.0,112.0,60.0,80.0,6.0,0,0,0,1,1,0,0,1 73 | 87.0,inf,48.0,80.0,3.0,0,0,0,1,1,0,0,1 74 | 231.0,inf,52.0,50.0,8.0,0,0,0,1,0,1,0,1 75 | 
242.0,242.0,70.0,50.0,1.0,0,0,0,1,1,0,0,1 76 | 991.0,991.0,50.0,70.0,7.0,0,0,0,1,0,1,0,1 77 | 111.0,111.0,62.0,70.0,3.0,0,0,0,1,1,0,0,1 78 | 1.0,1.0,65.0,20.0,21.0,0,0,0,1,0,1,0,1 79 | 587.0,587.0,58.0,60.0,3.0,0,0,0,1,1,0,0,1 80 | 389.0,389.0,62.0,90.0,2.0,0,0,0,1,1,0,0,1 81 | 33.0,33.0,64.0,30.0,6.0,0,0,0,1,1,0,0,1 82 | 25.0,25.0,63.0,20.0,36.0,0,0,0,1,1,0,0,1 83 | 357.0,357.0,58.0,70.0,13.0,0,0,0,1,1,0,0,1 84 | 467.0,467.0,64.0,90.0,2.0,0,0,0,1,1,0,0,1 85 | 201.0,201.0,52.0,80.0,28.0,0,0,0,1,0,1,0,1 86 | 1.0,1.0,35.0,50.0,7.0,0,0,0,1,1,0,0,1 87 | 30.0,30.0,63.0,70.0,11.0,0,0,0,1,1,0,0,1 88 | 44.0,44.0,70.0,60.0,13.0,0,0,0,1,0,1,0,1 89 | 283.0,283.0,51.0,90.0,2.0,0,0,0,1,1,0,0,1 90 | 15.0,15.0,40.0,50.0,13.0,0,0,0,1,0,1,0,1 91 | 25.0,25.0,69.0,30.0,2.0,0,0,1,0,1,0,0,1 92 | 103.0,inf,36.0,70.0,22.0,0,0,1,0,0,1,0,1 93 | 21.0,21.0,71.0,20.0,4.0,0,0,1,0,1,0,0,1 94 | 13.0,13.0,62.0,30.0,2.0,0,0,1,0,1,0,0,1 95 | 87.0,87.0,60.0,60.0,2.0,0,0,1,0,1,0,0,1 96 | 2.0,2.0,44.0,40.0,36.0,0,0,1,0,0,1,0,1 97 | 20.0,20.0,54.0,30.0,9.0,0,0,1,0,0,1,0,1 98 | 7.0,7.0,66.0,20.0,11.0,0,0,1,0,1,0,0,1 99 | 24.0,24.0,49.0,60.0,8.0,0,0,1,0,1,0,0,1 100 | 99.0,99.0,72.0,70.0,3.0,0,0,1,0,1,0,0,1 101 | 8.0,8.0,68.0,80.0,2.0,0,0,1,0,1,0,0,1 102 | 99.0,99.0,62.0,85.0,4.0,0,0,1,0,1,0,0,1 103 | 61.0,61.0,71.0,70.0,2.0,0,0,1,0,1,0,0,1 104 | 25.0,25.0,70.0,70.0,2.0,0,0,1,0,1,0,0,1 105 | 95.0,95.0,61.0,70.0,1.0,0,0,1,0,1,0,0,1 106 | 80.0,80.0,71.0,50.0,17.0,0,0,1,0,1,0,0,1 107 | 51.0,51.0,59.0,30.0,87.0,0,0,1,0,0,1,0,1 108 | 29.0,29.0,67.0,40.0,8.0,0,0,1,0,1,0,0,1 109 | 24.0,24.0,60.0,40.0,2.0,1,0,0,0,1,0,0,1 110 | 18.0,18.0,69.0,40.0,5.0,1,0,0,0,0,1,0,1 111 | 83.0,inf,57.0,99.0,3.0,1,0,0,0,1,0,0,1 112 | 31.0,31.0,39.0,80.0,3.0,1,0,0,0,1,0,0,1 113 | 51.0,51.0,62.0,60.0,5.0,1,0,0,0,1,0,0,1 114 | 90.0,90.0,50.0,60.0,22.0,1,0,0,0,0,1,0,1 115 | 52.0,52.0,43.0,60.0,3.0,1,0,0,0,1,0,0,1 116 | 73.0,73.0,70.0,60.0,3.0,1,0,0,0,1,0,0,1 117 | 8.0,8.0,66.0,50.0,5.0,1,0,0,0,1,0,0,1 118 | 
36.0,36.0,61.0,70.0,8.0,1,0,0,0,1,0,0,1 119 | 48.0,48.0,81.0,10.0,4.0,1,0,0,0,1,0,0,1 120 | 7.0,7.0,58.0,40.0,4.0,1,0,0,0,1,0,0,1 121 | 140.0,140.0,63.0,70.0,3.0,1,0,0,0,1,0,0,1 122 | 186.0,186.0,60.0,90.0,3.0,1,0,0,0,1,0,0,1 123 | 84.0,84.0,62.0,80.0,4.0,1,0,0,0,0,1,0,1 124 | 19.0,19.0,42.0,50.0,10.0,1,0,0,0,1,0,0,1 125 | 45.0,45.0,69.0,40.0,3.0,1,0,0,0,1,0,0,1 126 | 80.0,80.0,63.0,40.0,4.0,1,0,0,0,1,0,0,1 127 | 52.0,52.0,45.0,60.0,4.0,0,1,0,0,1,0,0,1 128 | 164.0,164.0,68.0,70.0,15.0,0,1,0,0,0,1,0,1 129 | 19.0,19.0,39.0,30.0,4.0,0,1,0,0,0,1,0,1 130 | 53.0,53.0,66.0,60.0,12.0,0,1,0,0,1,0,0,1 131 | 15.0,15.0,63.0,30.0,5.0,0,1,0,0,1,0,0,1 132 | 43.0,43.0,49.0,60.0,11.0,0,1,0,0,0,1,0,1 133 | 340.0,340.0,64.0,80.0,10.0,0,1,0,0,0,1,0,1 134 | 133.0,133.0,65.0,75.0,1.0,0,1,0,0,1,0,0,1 135 | 111.0,111.0,64.0,60.0,5.0,0,1,0,0,1,0,0,1 136 | 231.0,231.0,67.0,70.0,18.0,0,1,0,0,0,1,0,1 137 | 378.0,378.0,65.0,80.0,4.0,0,1,0,0,1,0,0,1 138 | 49.0,49.0,37.0,30.0,3.0,0,1,0,0,1,0,0,1 139 | -------------------------------------------------------------------------------- /Sources/XGBoost/Array.swift: -------------------------------------------------------------------------------- 1 | import Foundation 2 | 3 | public extension Array where Element == Feature { 4 | /// Save feature map compatible with XGBoost`s inputs. 5 | /// 6 | /// - Parameter to: Path where feature map will be saved. 7 | func saveFeatureMap(to path: String) throws { 8 | try enumerated() 9 | .map { "\($0) \($1.name) \($1.type.rawValue)" } 10 | .joined(separator: "\n") 11 | .write( 12 | toFile: path, 13 | atomically: true, 14 | encoding: .utf8 15 | ) 16 | } 17 | 18 | /// Load previously saved feature map. 19 | /// 20 | /// - Parameter from: Path to the feature map. 
public extension Array where Element == Feature {
    /// Save feature map compatible with XGBoost's inputs.
    ///
    /// One line is written per feature: `<index> <name> <type>`.
    ///
    /// - Parameter to: Path where feature map will be saved.
    func saveFeatureMap(to path: String) throws {
        try enumerated()
            .map { "\($0) \($1.name) \($1.type.rawValue)" }
            .joined(separator: "\n")
            .write(
                toFile: path,
                atomically: true,
                encoding: .utf8
            )
    }

    /// Load previously saved feature map.
    ///
    /// Blank lines (for example the one produced by a trailing newline) are skipped.
    ///
    /// - Parameter from: Path to the feature map.
    /// - Throws: `ValueError.runtimeError` for a line with fewer than three fields or an unknown feature type.
    init(fromFeatureMap path: String) throws {
        let content = try String(contentsOfFile: path, encoding: .utf8)
        var features = [Feature]()

        for line in content.components(separatedBy: .newlines) {
            // Fields are separated by spaces or tabs; empty fields are dropped.
            let fields = line.split(whereSeparator: { $0 == " " || $0 == "\t" })

            if fields.isEmpty {
                continue
            }

            guard fields.count >= 3 else {
                throw ValueError.runtimeError("Invalid feature map line \(line).")
            }

            let (name, type) = (String(fields[1]), String(fields[2]))

            guard let featureType = FeatureType(rawValue: type) else {
                throw ValueError.runtimeError("Invalid feature type \(type).")
            }

            features.append(Feature(name: name, type: featureType))
        }

        self.init(features)
    }
}

extension Array: FloatData where Element == Float {
    /// - Returns: self
    public func data() throws -> [Float] {
        self
    }
}

extension Array: Int32Data where Element == Int32 {
    /// - Returns: self
    public func data() throws -> [Int32] {
        self
    }
}

extension Array: UInt32Data where Element == UInt32 {
    /// - Returns: self
    public func data() throws -> [UInt32] {
        self
    }
}

public extension Array {
    /// Splits the array into consecutive chunks of `ceil(count / chunks)` elements.
    ///
    /// The last chunk may be shorter, and fewer than `chunks` chunks are returned
    /// when the elements run out earlier (5 elements into 4 chunks gives 2, 2, 1).
    ///
    /// - Parameter into: Number of chunks.
    /// - Returns: Chunks in original order; empty for an empty array or a non-positive `chunks`.
    func chunked(into chunks: Int) -> [[Element]] {
        // Avoids the division by zero and the zero stride of the unguarded computation.
        guard chunks > 0, !isEmpty else {
            return []
        }

        // Integer ceiling of count / chunks.
        let size = (count + chunks - 1) / chunks

        return stride(from: 0, to: count, by: size).map {
            Array(self[$0 ..< Swift.min($0 + size, count)])
        }
    }
}

public extension Array where Element: AdditiveArithmetic {
    /// - Returns: Array calculated from self as [self[i + 1] - self[i]] for i = 0 ..< count - 1,
    ///            empty when self has fewer than two elements.
    func diff() -> [Element] {
        // Pair every element with its predecessor; zip stops at the shorter sequence,
        // so empty and single-element arrays produce no pairs.
        zip(dropFirst(), self).map { $0 - $1 }
    }
}
public extension Array where Element: FloatingPoint {
    /// - Parameter from: Initial value for sum.
    /// - Returns: Sum of elements added to `initialValue`.
    func sum(
        from initialValue: Element = 0
    ) -> Element {
        reduce(initialValue, +)
    }

    /// - Parameter sum: Precalculated sum.
    /// - Returns: Mean of elements (NaN for an empty array when no sum is given).
    func mean(
        sum: Element? = nil
    ) -> Element {
        // `??` binds weaker than `/`, so the parentheses are required. Without them
        // a precalculated sum is returned as-is instead of being divided by the count.
        (sum ?? self.sum()) / Element(count)
    }

    /// - Parameter sum: Precalculated sum.
    /// - Parameter mean: Precalculated mean.
    /// - Parameter ddof: Delta degrees of freedom; the divisor is `count - ddof`.
    /// - Returns: STD of elements.
    func std(
        sum: Element? = nil,
        mean: Element? = nil,
        ddof: Int = 0
    ) -> Element {
        let mean = mean ?? self.mean(sum: sum)
        let squaredDeviations = reduce(0) { $0 + ($1 - mean) * ($1 - mean) }
        return (squaredDeviations / Element(count - ddof)).squareRoot()
    }
}
public extension Array where Element == AfterIterationOutput {
    /// - Returns: Whether array contains `AfterIterationOutput.stop`.
    var willStop: Bool {
        contains { output in output == .stop }
    }
}

/// Flat array stored together with the shape it is reported with.
public struct ArrayWithShape<Element>: Equatable, Sequence, ShapeData where Element: Equatable {
    /// Flat storage of the elements.
    public var array: [Element]
    /// Shape returned by `dataShape()`.
    public var shape: Shape

    /// Number of elements in the flat storage.
    public var count: Int {
        return array.count
    }

    public init(_ array: [Element], shape: Shape) {
        self.array = array
        self.shape = shape
    }

    /// - Returns: The stored shape.
    public func dataShape() throws -> Shape {
        return shape
    }

    /// Element access by index into the flat storage.
    public subscript(index: Int) -> Element {
        get {
            return array[index]
        }
        set {
            array[index] = newValue
        }
    }

    /// Equal when both the elements and the shape match.
    public static func == (lhs: Self, rhs: Self) -> Bool {
        guard lhs.array == rhs.array else {
            return false
        }

        return lhs.shape == rhs.shape
    }

    /// Comparison with a plain array looks at the elements only.
    public static func == (lhs: [Element], rhs: Self) -> Bool {
        return lhs == rhs.array
    }

    /// Comparison with a plain array looks at the elements only.
    public static func == (lhs: Self, rhs: [Element]) -> Bool {
        return lhs.array == rhs
    }

    public func makeIterator() -> ArrayWithShapeIterator<Element> {
        return ArrayWithShapeIterator(self)
    }
}

/// Iterator over the flat elements of an `ArrayWithShape`, in storage order.
public struct ArrayWithShapeIterator<Element>: IteratorProtocol where Element: Equatable {
    var index = 0
    let arrayWithShape: ArrayWithShape<Element>

    init(_ arrayWithShape: ArrayWithShape<Element>) {
        self.arrayWithShape = arrayWithShape
    }

    public mutating func next() -> Element? {
        guard index < arrayWithShape.count else {
            return nil
        }

        let element = arrayWithShape[index]
        index += 1
        return element
    }
}

extension ArrayWithShape: FloatData where Element == Float {
    /// - Returns: [Float]
    public func data() throws -> [Float] {
        return array
    }
}

extension ArrayWithShape: Int32Data where Element == Int32 {
    /// - Returns: [Int32]
    public func data() throws -> [Int32] {
        return array
    }
}

extension ArrayWithShape: UInt32Data where Element == UInt32 {
    /// - Returns: [UInt32]
    public func data() throws -> [UInt32] {
        return array
    }
}
ArrayWithShape([10, 3, 1], shape: Shape(3)) 32 | XCTAssertEqual(try data.isFlat(), true) 33 | } 34 | 35 | func testInitWithWeight() throws { 36 | let data = try DMatrix( 37 | name: "data", 38 | from: [1, 2, 3, 3, 4, 6], 39 | shape: Shape(3, 2), 40 | weight: [1, 2, 3] 41 | ) 42 | XCTAssertEqual(try data.get(field: .weight), [1, 2, 3]) 43 | } 44 | 45 | func testInitWithBaseMargin() throws { 46 | let data = try DMatrix( 47 | name: "data", 48 | from: [1, 2, 3, 3, 4, 6], 49 | shape: Shape(3, 2), 50 | baseMargin: [1, 2, 3] 51 | ) 52 | XCTAssertEqual(try data.get(field: .baseMargin), [1, 2, 3]) 53 | } 54 | 55 | func testInitWithBound() throws { 56 | let data = try DMatrix( 57 | name: "data", 58 | from: [1, 2, 3, 3, 4, 6], 59 | shape: Shape(3, 2), 60 | labelLowerBound: [1, 2, 3], 61 | labelUpperBound: [2, 3, 4] 62 | ) 63 | XCTAssertEqual(try data.get(field: .labelLowerBound), [1, 2, 3]) 64 | XCTAssertEqual(try data.get(field: .labelUpperBound), [2, 3, 4]) 65 | } 66 | 67 | func testCounts() throws { 68 | let randomArray = (0 ..< 10).map { _ in Float.random(in: 0 ..< 2) } 69 | let data = try DMatrix( 70 | name: "data", 71 | from: randomArray, 72 | shape: Shape(5, 2), 73 | threads: 1 74 | ) 75 | XCTAssertEqual(try data.rowCount(), 5) 76 | XCTAssertEqual(try data.columnCount(), 2) 77 | XCTAssertEqual(try data.shape(), [5, 2]) 78 | } 79 | 80 | func testFromCSVFile() throws { 81 | let csvFile = temporaryFile() 82 | 83 | let fileContent = """ 84 | 1.0,2.0,3.0 85 | 0.0,4.0,5.0 86 | 1.0,6.0,7.0 87 | """ 88 | 89 | try fileContent.write( 90 | toFile: csvFile, 91 | atomically: true, 92 | encoding: .utf8 93 | ) 94 | 95 | let data = try DMatrix( 96 | name: "test", 97 | from: csvFile, 98 | format: .csv, 99 | labelColumn: 0 100 | ) 101 | 102 | XCTAssertEqual(try data.get(field: .label), [1.0, 0.0, 1.0]) 103 | XCTAssertEqual(try data.rowCount(), 3) 104 | XCTAssertEqual(try data.columnCount(), 2) 105 | XCTAssertEqual(try data.shape(), [3, 2]) 106 | } 107 | 108 | func testSaveAndLoadBinary() 
throws { 109 | let data = try DMatrix( 110 | name: "data", 111 | from: [1, 2, 3, 3, 4, 6], 112 | shape: Shape(3, 2), 113 | label: [1, 0, 1] 114 | ) 115 | 116 | let path = temporaryFile() 117 | 118 | try data.save(to: path) 119 | 120 | let loadedData = try DMatrix( 121 | name: "loaded", 122 | from: path, 123 | format: .binary 124 | ) 125 | 126 | XCTAssertEqual(try data.get(field: .label), try loadedData.get(field: .label)) 127 | XCTAssertEqual(try data.rowCount(), try loadedData.rowCount()) 128 | XCTAssertEqual(try data.columnCount(), try loadedData.columnCount()) 129 | XCTAssertEqual(try data.shape(), try loadedData.shape()) 130 | XCTAssertEqual(try data.shape(), [3, 2]) 131 | } 132 | 133 | func testSaveAndLoadFeatureMap() throws { 134 | let features = [ 135 | Feature("x", .quantitative), 136 | Feature("y", .quantitative), 137 | ] 138 | 139 | let path = temporaryFile() 140 | 141 | try features.saveFeatureMap(to: path) 142 | 143 | let loadedFeatures = try [Feature](fromFeatureMap: path) 144 | 145 | XCTAssertEqual(features, loadedFeatures) 146 | 147 | let data = try DMatrix( 148 | name: "data", 149 | from: [1, 2, 3, 3, 4, 6], 150 | shape: Shape(3, 2), 151 | label: [1, 0, 1] 152 | ) 153 | try data.loadFeatureMap(from: path) 154 | 155 | XCTAssertEqual(try data.features(), loadedFeatures) 156 | } 157 | 158 | func testSlice() throws { 159 | let data = try DMatrix( 160 | name: "data", 161 | from: [1, 2, 3, 3, 4, 6], 162 | shape: Shape(3, 2), 163 | label: [1, 0, 1] 164 | ) 165 | 166 | let firstSlice = try data.slice( 167 | indexes: [0, 2], 168 | newName: "firstSlice" 169 | ) 170 | XCTAssertEqual(firstSlice.name, "firstSlice") 171 | XCTAssertEqual(try firstSlice.shape(), [2, 2]) 172 | XCTAssertEqual(try firstSlice.get(field: .label), [1, 1]) 173 | 174 | let secondSlice = try data.slice( 175 | indexes: [1], 176 | newName: "secondSlice" 177 | ) 178 | XCTAssertEqual(secondSlice.name, "secondSlice") 179 | XCTAssertEqual(try secondSlice.shape(), [1, 2]) 180 | XCTAssertEqual(try 
secondSlice.get(field: .label), [0]) 181 | 182 | let thirdSlice = try data.slice( 183 | indexes: 0 ..< 2, 184 | newName: "thirdSlice" 185 | ) 186 | XCTAssertEqual(thirdSlice.name, "thirdSlice") 187 | XCTAssertEqual(try thirdSlice.shape(), [2, 2]) 188 | XCTAssertEqual(try thirdSlice.get(field: .label), [1, 0]) 189 | } 190 | 191 | func testSetGetFloatInfo() throws { 192 | let data = try DMatrix( 193 | name: "data", 194 | from: [1, 2, 3, 3, 4, 6], 195 | shape: Shape(3, 2) 196 | ) 197 | 198 | try data.set(field: .label, values: [1, 2, 3]) 199 | XCTAssertEqual(try data.get(field: .label), [1, 2, 3]) 200 | 201 | try data.set(field: .weight, values: [1, 3, 3]) 202 | XCTAssertEqual(try data.get(field: .weight), [1, 3, 3]) 203 | } 204 | 205 | func testSetGetUIntInfo() throws { 206 | let data = try DMatrix( 207 | name: "data", 208 | from: [1, 2, 3, 3, 4, 6], 209 | shape: Shape(6, 1) 210 | ) 211 | 212 | try data.set(field: .group, values: [1, 1, 1, 3]) 213 | XCTAssertEqual(try data.get(field: .groupPtr), [0, 1, 2, 3, 6]) 214 | } 215 | 216 | static var allTests = [ 217 | ("testCounts", testCounts), 218 | ("testFromCSVFile", testFromCSVFile), 219 | ("testSaveAndLoadBinary", testSaveAndLoadBinary), 220 | ("testSaveAndLoadFeatureMap", testSaveAndLoadFeatureMap), 221 | ("testSlice", testSlice), 222 | ("testSetGetFloatInfo", testSetGetFloatInfo), 223 | ("testSetGetUIntInfo", testSetGetUIntInfo), 224 | ] 225 | } 226 | -------------------------------------------------------------------------------- /doc/swift_intro.rst: -------------------------------------------------------------------------------- 1 | ########################### 2 | Swift Package Introduction 3 | ########################### 4 | This document gives a basic walkthrough of xgboost swift package. 5 | 6 | 7 | **List of other helpful links** 8 | 9 | * `Swift examples `_ 10 | * `Swift API Reference `_ 11 | 12 | 13 | Install XGBoost 14 | --------------- 15 | To install XGBoost, please follow instructions at `GitHub `_. 
16 | 17 | 18 | Include in your project 19 | ----------------------- 20 | `SwiftXGBoost `_ uses `Swift Package Manager `_, to use it in your project, simply add it as a dependency in your Package.swift file: 21 | 22 | .. code-block:: swift 23 | 24 | .package(url: "https://github.com/kongzii/SwiftXGBoost.git", from: "0.7.0") 25 | 26 | 27 | Python compatibility 28 | -------------------- 29 | With `PythonKit `_ package, you can import Python modules: 30 | 31 | .. code-block:: swift 32 | 33 | let numpy = Python.import("numpy") 34 | let pandas = Python.import("pandas") 35 | 36 | And use them in the same way as in Python: 37 | 38 | .. code-block:: swift 39 | 40 | let dataFrame = pandas.read_csv("Examples/Data/veterans_lung_cancer.csv") 41 | 42 | And then use them with `SwiftXGBoost `_, 43 | check `AftSurvival `_ for a complete example. 44 | 45 | 46 | TensorFlow compatibility 47 | ------------------------ 48 | If you are using `S4TF toolchains `_, you can utilize tensors directly: 49 | 50 | .. code-block:: swift 51 | 52 | let tensor = Tensor(shape: TensorShape([2, 3]), scalars: [1, 2, 3, 4, 5, 6]) 53 | let tensorData = try DMatrix(name: "tensorData", from: tensor) 54 | 55 | 56 | Data Interface 57 | -------------- 58 | The XGBoost swift package is currently able to load data from: 59 | 60 | - LibSVM text format file 61 | - Comma-separated values (CSV) file 62 | - NumPy 2D array 63 | - `Swift for Tensorflow `_ 2D Tensor 64 | - XGBoost binary buffer file 65 | 66 | The data is stored in a `DMatrix `_ class. 67 | 68 | To load a libsvm text file into `DMatrix `_ class: 69 | 70 | .. code-block:: swift 71 | 72 | let svmData = try DMatrix(name: "train", from: "Examples/Data/data.svm.txt", format: .libsvm) 73 | 74 | To load a CSV file into `DMatrix `_: 75 | 76 | .. 
code-block:: swift 77 | 78 | // labelColumn specifies the index of the column containing the true label 79 | let csvData = try DMatrix(name: "train", from: "Examples/Data/data.csv", format: .csv, labelColumn: 0) 80 | 81 | .. note:: Use Pandas to load CSV files with headers. 82 | 83 | Currently, the DMLC data parser cannot parse CSV files with headers. Use Pandas (see below) to read CSV files with headers. 84 | 85 | To load a NumPy array into `DMatrix `_: 86 | 87 | .. code-block:: swift 88 | 89 | let numpyData = try DMatrix(name: "train", from: numpy.random.rand(5, 10), label: numpy.random.randint(2, size: 5)) 90 | 91 | To load a Pandas data frame into `DMatrix `_: 92 | 93 | .. code-block:: swift 94 | 95 | let pandasDataFrame = pandas.DataFrame(numpy.arange(12).reshape([4, 3]), columns: ["a", "b", "c"]) 96 | let pandasLabel = numpy.random.randint(2, size: 4) 97 | let pandasData = try DMatrix(name: "data", from: pandasDataFrame.values, label: pandasLabel) 98 | 99 | Saving `DMatrix `_ into an XGBoost binary file will make loading faster: 100 | 101 | .. code-block:: swift 102 | 103 | try pandasData.save(to: "train.buffer") 104 | 105 | Missing values can be replaced by a default value in the `DMatrix `_ constructor: 106 | 107 | .. code-block:: swift 108 | 109 | let dataWithMissingValues = try DMatrix(name: "data", from: pandasDataFrame.values, missingValue: 999.0) 110 | 111 | Various `float fields `_ and `uint fields `_ can be set when needed: 112 | 113 | .. code-block:: swift 114 | 115 | try dataWithMissingValues.set(field: .weight, values: [Float](repeating: 1, count: try dataWithMissingValues.rowCount())) 116 | 117 | And returned: 118 | 119 | .. code-block:: swift 120 | 121 | let labelsFromData = try pandasData.get(field: .label) 122 | 123 | 124 | Setting Parameters 125 | ------------------ 126 | Parameters for `Booster `_ can also be set. 127 | 128 | Using the set method: 129 | 130 | .. 
code-block:: swift 131 | 132 | let firstBooster = try Booster() 133 | try firstBooster.set(parameter: "tree_method", value: "hist") 134 | 135 | Or as a list at initialization: 136 | 137 | .. code-block:: swift 138 | 139 | let parameters = [Parameter(name: "tree_method", value: "hist")] 140 | let secondBooster = try Booster(parameters: parameters) 141 | 142 | 143 | Training 144 | -------- 145 | Training a model requires a booster and a dataset. 146 | 147 | .. code-block:: swift 148 | 149 | let trainingData = try DMatrix(name: "train", from: "Examples/Data/data.csv", format: .csv, labelColumn: 0) 150 | let boosterWithCachedData = try Booster(with: [trainingData]) 151 | try boosterWithCachedData.train(iterations: 5, trainingData: trainingData) 152 | 153 | After training, the model can be saved: 154 | 155 | .. code-block:: swift 156 | 157 | try boosterWithCachedData.save(to: "0001.xgboost") 158 | 159 | The model can also be dumped to a text: 160 | 161 | .. code-block:: swift 162 | 163 | let textModel = try boosterWithCachedData.dumped(format: .text) 164 | 165 | A saved model can be loaded as follows: 166 | 167 | .. code-block:: swift 168 | 169 | let loadedBooster = try Booster(from: "0001.xgboost") 170 | 171 | 172 | Prediction 173 | ---------- 174 | A model that has been trained or loaded can perform predictions on data sets. 175 | 176 | From Numpy array: 177 | 178 | .. code-block:: swift 179 | 180 | let testDataNumpy = try DMatrix(name: "test", from: numpy.random.rand(7, 12)) 181 | let predictionNumpy = try loadedBooster.predict(from: testDataNumpy) 182 | 183 | From Swift array: 184 | 185 | .. code-block:: swift 186 | 187 | let testData = try DMatrix(name: "test", from: [69.0,60.0,7.0,0,0,0,1,1,0,1,0,0], shape: Shape(1, 12)) 188 | let prediction = try loadedBooster.predict(from: testData) 189 | 190 | 191 | Plotting 192 | -------- 193 | You can also save the plot of importance into a file: 194 | 195 | .. 
code-block:: swift 196 | 197 | try boosterWithCachedData.saveImportanceGraph(to: "importance") // .svg extension will be added 198 | 199 | 200 | C API 201 | -------- 202 | Both `Booster `_ and `DMatrix `_ are exposing pointers to the underlying C. 203 | 204 | You can import a C-API library: 205 | 206 | .. code-block:: swift 207 | 208 | import CXGBoost 209 | 210 | And use it directly in your Swift code: 211 | 212 | .. code-block:: swift 213 | 214 | try safe { XGBoosterSaveModel(boosterWithCachedData.booster, "0002.xgboost") } 215 | 216 | `safe` is a helper function that will throw an error if C-API call fails. 217 | 218 | 219 | More 220 | -------- 221 | For more details and examples, check out `GitHub repository `_. 222 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | [![codecov](https://codecov.io/gh/kongzii/SwiftXGBoost/branch/master/graph/badge.svg)](https://codecov.io/gh/kongzii/SwiftXGBoost) 2 | [![Platform](https://img.shields.io/badge/platform-linux%2Cmacos-lightgrey)](https://kongzii.github.io/SwiftXGBoost/) 3 | [![Swift Version](https://img.shields.io/badge/Swift-5.2-green.svg)]() 4 | [![PRs Welcome](https://img.shields.io/badge/PRs-welcome-brightgreen.svg?style=flat-square)](http://makeapullrequest.com) 5 | ![Ubuntu](https://github.com/kongzii/SwiftXGBoost/workflows/Ubuntu/badge.svg) 6 | ![MacOS](https://github.com/kongzii/SwiftXGBoost/workflows/MacOS/badge.svg) 7 | 8 | # XGBoost for Swift 9 | 10 | Bindings for [the XGBoost system library](https://en.wikipedia.org/wiki/XGBoost). 11 | The aim of this package is to mimic [XGBoost Python bindings](https://xgboost.readthedocs.io/en/latest/python/python_intro.html) but, at the same time, utilize the power of Swift and C compatibility. Some things thus behave differently but should provide you maximum flexibility over XGBoost. 
12 | 13 | Check out: 14 | 15 | - [Examples](https://github.com/kongzii/SwiftXGBoost/tree/master/Examples) 16 | - [Code documentation](https://kongzii.github.io/SwiftXGBoost/) 17 | - [ReadTheDocs](https://swiftxgboost.readthedocs.io/) 18 | 19 | ## Installation 20 | 21 | ### System library dependency 22 | 23 | #### Linux 24 | 25 | Install XGBoost from sources 26 | 27 | ``` 28 | git clone https://github.com/dmlc/xgboost 29 | cd xgboost 30 | git checkout tags/v1.1.1 31 | git submodule update --init --recursive 32 | mkdir build 33 | cd build 34 | cmake .. 35 | make 36 | make install 37 | ldconfig 38 | ``` 39 | 40 | Or you can use the provided installation script 41 | 42 | ``` 43 | ./install.sh 44 | ``` 45 | 46 | #### macOS 47 | 48 | You can build and install the same way as on Linux, or just use brew 49 | 50 | ``` 51 | brew install xgboost 52 | ``` 53 | 54 | ##### Note 55 | 56 | Before version 1.1.1, XGBoost did not create a pkg-config file. This was fixed with PR [Add pkgconfig to cmake #5744](https://github.com/dmlc/xgboost/pull/5744). 57 | 58 | If for some reason you are using an older version, you may need to specify the path to the XGBoost libraries while building, e.g.: 59 | 60 | ``` 61 | swift build -Xcc -I/usr/local/include -Xlinker -L/usr/local/lib 62 | ``` 63 | 64 | or create the pkg-config file manually. An example for `macOS 10.15` and `XGBoost 1.1.0` is 65 | 66 | ``` 67 | prefix=/usr/local/Cellar/xgboost/1.1.0 68 | exec_prefix=${prefix}/bin 69 | libdir=${prefix}/lib 70 | includedir=${prefix}/include 71 | 72 | Name: xgboost 73 | Description: XGBoost machine learning library.
74 | Version: 1.1.0 75 | Cflags: -I${includedir} 76 | Libs: -L${libdir} -lxgboost 77 | ``` 78 | 79 | and needs to be placed at `/usr/local/lib/pkgconfig/xgboost.pc` 80 | 81 | ### Package 82 | 83 | Add a dependency in your `Package.swift` 84 | 85 | ```swift 86 | .package(url: "https://github.com/kongzii/SwiftXGBoost.git", from: "0.0.0"), 87 | ``` 88 | 89 | Import Swifty XGBoost 90 | 91 | ```swift 92 | import XGBoost 93 | ``` 94 | 95 | or the C library directly 96 | 97 | ```swift 98 | import CXGBoost 99 | ``` 100 | 101 | Both `Booster` and `DMatrix` classes expose pointers to the underlying C, 102 | so you can utilize the C-API directly for more advanced usage. 103 | 104 | As the library is still evolving, there can be incompatible changes between updates; 105 | the releases before version 1.0.0 don't follow [Semantic Versioning](https://semver.org/). 106 | Please use the exact version if you do not want to worry about updating your packages. 107 | 108 | ```swift 109 | .package(url: "https://github.com/kongzii/SwiftXGBoost.git", .exact("0.1.0")), 110 | ``` 111 | 112 | ## Python compatibility 113 | 114 | DMatrix can be created from a numpy array just like in Python 115 | 116 | ```swift 117 | let pandas = Python.import("pandas") 118 | let dataFrame = pandas.read_csv("data.csv") 119 | let data = try DMatrix( 120 | name: "training", 121 | from: dataFrame.values 122 | ) 123 | ``` 124 | 125 | and the Swift array can be converted back to numpy 126 | 127 | ```swift 128 | let predicted = try booster.predict( 129 | from: validationData 130 | ) 131 | 132 | let compare = pandas.DataFrame([ 133 | "Label lower bound": yLowerBound[validIndex], 134 | "Label upper bound": yUpperBound[validIndex], 135 | "Predicted": predicted.makeNumpyArray(), 136 | ]) 137 | 138 | print(compare) 139 | ``` 140 | 141 | This is possible thanks to the [PythonKit](https://github.com/pvieito/PythonKit.git).
142 | For more detailed usage and workarounds for known issues, check out [examples](https://github.com/kongzii/SwiftXGBoost/tree/master/Examples). 143 | 144 | ## TensorFlow compatibility 145 | 146 | [Swift4TensorFlow](https://github.com/tensorflow/swift) is a great project from Google. 147 | If you are using one of the S4TF swift toolchains, you can combine its power directly with XGBoost. 148 | 149 | ```swift 150 | let tensor = Tensor(shape: TensorShape([2, 3]), scalars: [1, 2, 3, 4, 5, 6]) 151 | let data = try DMatrix(name: "training", from: tensor) 152 | ``` 153 | 154 | ### Note 155 | 156 | [Swift4TensorFlow](https://github.com/tensorflow/swift) toolchains ship with a preinstalled [PythonKit](https://github.com/pvieito/PythonKit.git) and you may run into a problem when using a package with an extra [PythonKit](https://github.com/pvieito/PythonKit.git) dependency. If so, please use the package version with the `-tensorflow` suffix, in which the [PythonKit](https://github.com/pvieito/PythonKit.git) dependency is removed. 157 | 158 | ```swift 159 | .package(url: "https://github.com/kongzii/SwiftXGBoost.git", .exact("0.7.0-tensorflow")), 160 | ``` 161 | 162 | This bug is known and hopefully will be resolved soon.
163 | 164 | ## Examples 165 | 166 | More examples can be found in [Examples directory](https://github.com/kongzii/SwiftXGBoost/tree/master/Examples) 167 | and run inside docker 168 | 169 | ``` 170 | docker-compose run swiftxgboost swift run exampleName 171 | ``` 172 | 173 | or on host 174 | 175 | ``` 176 | swift run exampleName 177 | ``` 178 | 179 | ### Basic functionality 180 | 181 | ```swift 182 | import XGBoost 183 | 184 | // Register your own callback function for log(info) messages 185 | try XGBoost.registerLogCallback { 186 | print("Swifty log:", String(cString: $0!)) 187 | } 188 | 189 | // Create some random features and labels 190 | let randomArray = (0 ..< 1000).map { _ in Float.random(in: 0 ..< 2) } 191 | let labels = (0 ..< 100).map { _ in Float([0, 1].randomElement()!) } 192 | 193 | // Initialize data, DMatrixHandle in the background 194 | let data = try DMatrix( 195 | name: "data", 196 | from: randomArray, 197 | shape: Shape(100, 10), 198 | label: labels, 199 | threads: 1 200 | ) 201 | 202 | // Slice array into train and test 203 | let train = try data.slice(indexes: 0 ..< 90, newName: "train") 204 | let test = try data.slice(indexes: 90 ..< 100, newName: "test") 205 | 206 | // Parameters for Booster, check https://xgboost.readthedocs.io/en/latest/parameter.html 207 | let parameters = [ 208 | Parameter("verbosity", "2"), 209 | Parameter("seed", "0"), 210 | ] 211 | 212 | // Create Booster model, `with` data will be cached 213 | let booster = try Booster( 214 | with: [train, test], 215 | parameters: parameters 216 | ) 217 | 218 | // Train booster, optionally provide callback functions called before and after each iteration 219 | try booster.train( 220 | iterations: 10, 221 | trainingData: train, 222 | evaluationData: [train, test] 223 | ) 224 | 225 | // Predict from test data 226 | let predictions = try booster.predict(from: test) 227 | 228 | // Save 229 | try booster.save(to: "model.xgboost") 230 | ``` 231 | 232 | ## Development 233 | 234 | ### 
Documentation 235 | 236 | [Jazzy](https://github.com/realm/jazzy) is used for the generation of documentation. 237 | 238 | You can generate documentation locally using 239 | 240 | ``` 241 | make documentation 242 | ``` 243 | 244 | GitHub Pages will be updated automatically when changes are merged into master. 245 | 246 | ### Tests 247 | 248 | Where possible, the Swift implementation is tested against the reference implementation in Python via PythonKit. For example, the test of the `score` method in `scoreEmptyFeatureMapTest` 249 | 250 | ```swift 251 | let pyFMap = [String: Int](pyXgboost.get_score( 252 | fmap: "", importance_type: "weight"))! 253 | let (fMap, _) = try booster.score(featureMap: "", importance: .weight) 254 | 255 | XCTAssertEqual(fMap, pyFMap) 256 | ``` 257 | 258 | #### Run locally 259 | 260 | On Ubuntu using Docker 261 | 262 | ``` 263 | docker-compose run test 264 | ``` 265 | 266 | On host 267 | 268 | ``` 269 | swift test 270 | ``` 271 | 272 | ### Code format 273 | 274 | [SwiftFormat](https://github.com/nicklockwood/SwiftFormat) is used for code formatting. 275 | 276 | ``` 277 | make format 278 | ``` 279 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity.
For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 
47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. 
Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. 
In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. 
We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /Sources/XGBoost/CrossValidation.swift: -------------------------------------------------------------------------------- 1 | import CXGBoost 2 | import Foundation 3 | 4 | /// Auxiliary class to hold one fold of CrossValidation. 5 | public class CVPack { 6 | /// CVPack identifier. 7 | public let id: String 8 | 9 | /// DMatrixes used for training and testing. 10 | public let train, test: DMatrix 11 | 12 | /// Booster class. 13 | public let booster: Booster 14 | 15 | /// Initialization of CVPack. Booster will get attribute `cvpack_id` with value of `id`. 16 | /// 17 | /// - Parameter id: Identifier for this fold. 18 | /// - Parameter train: Training DMatrix. 19 | /// - Parameter test: Testing DMatrix. 20 | /// - Parameter parameters: Parameters passed to Booster.
21 | public init( 22 | id: String, 23 | train: DMatrix, 24 | test: DMatrix, 25 | parameters: [Parameter] 26 | ) throws { 27 | self.id = id 28 | self.train = train 29 | self.test = test 30 | booster = try Booster( 31 | with: [train, test], 32 | parameters: parameters 33 | ) 34 | 35 | try booster.set(attribute: "cvpack_id", value: self.id) 36 | } 37 | } 38 | 39 | /// Given group row boundaries, convert group indexes to row indexes. 40 | /// 41 | /// - Parameter groups: Array of group indexes. 42 | /// - Parameter boundaries: Row index boundaries of the groups; group `g` spans rows `boundaries[g] ..< boundaries[g + 1]`. 43 | /// - Returns: Row indexes of all the given groups, concatenated in the order of `groups`. 44 | func groupsToRows( 45 | groups: [Int], 46 | boundaries: [UInt32] 47 | ) -> [UInt32] { 48 | groups.map { g in 49 | boundaries[g] ..< boundaries[g + 1] 50 | }.flatMap { $0 } 51 | } 52 | 53 | /// Make folds for cross-validation maintaining groups. 54 | /// 55 | /// - Parameter data: DMatrix data. 56 | /// - Parameter splits: Number of folds to create. 57 | /// - Parameter parameters: Parameters passed to Booster in CVPack. 58 | /// - Parameter shuffle: If true, rows in folds will be shuffled. 59 | /// - Returns: Array of CVPacks.
60 | func makeGroupNFold( 61 | data: Data, 62 | splits: Int, 63 | parameters: [Parameter], 64 | shuffle: Bool 65 | ) throws -> [CVPack] { 66 | let groupBoundaries = try data.get(field: .groupPtr) 67 | let groupSizes = groupBoundaries.diff() 68 | 69 | var idx = Array(0 ..< groupSizes.count) 70 | 71 | if shuffle { 72 | idx.shuffle() 73 | } 74 | 75 | let outGroupIdSet = idx.chunked(into: splits) 76 | let inGroupIdSet = (0 ..< splits).map { k in 77 | (0 ..< splits).filter { i in i != k }.flatMap { i in 78 | outGroupIdSet[i] 79 | } 80 | } 81 | 82 | let outIdSet = outGroupIdSet.map { 83 | groupsToRows(groups: $0, boundaries: groupBoundaries).map { Int32($0) } 84 | } 85 | let inIdSet = inGroupIdSet.map { 86 | groupsToRows(groups: $0, boundaries: groupBoundaries).map { Int32($0) } 87 | } 88 | 89 | var packs = [CVPack]() 90 | 91 | for k in 0 ..< splits { 92 | let train = try data.slice( 93 | indexes: inIdSet[k], 94 | allowGroups: true 95 | ) 96 | let trainGroupSizes = inGroupIdSet[k].map { groupSizes[Int($0)] } 97 | try train.set( 98 | field: .group, 99 | values: trainGroupSizes 100 | ) 101 | let test = try data.slice( 102 | indexes: outIdSet[k], 103 | allowGroups: true 104 | ) 105 | let testGroupSizes = outGroupIdSet[k].map { groupSizes[Int($0)] } 106 | try test.set( 107 | field: .group, 108 | values: testGroupSizes 109 | ) 110 | packs.append( 111 | try CVPack( 112 | id: "\(k)", 113 | train: train, 114 | test: test, 115 | parameters: parameters 116 | ) 117 | ) 118 | } 119 | 120 | return packs 121 | } 122 | 123 | /// Make folds for cross-validation. 124 | /// 125 | /// - Parameter data: DMatrix data. 126 | /// - Parameter splits: Number of folds to create. 127 | /// - Parameter parameters: Parameters passed to Booster in CVPack. 128 | /// - Parameter shuffle: If true, rows in folds will be shuffled. 129 | /// - Returns: Array of CVPacks. 
func makeNFold(
    data: Data,
    splits: Int,
    parameters: [Parameter],
    shuffle: Bool
) throws -> [CVPack] {
    // Data that carries more than one group pointer is delegated to the group-aware splitter.
    if try data.get(field: .groupPtr).count > 1 {
        return try makeGroupNFold(
            data: data,
            splits: splits,
            parameters: parameters,
            shuffle: shuffle
        )
    }

    var idx = try Array(0 ..< Int32(data.rowCount()))

    if shuffle {
        idx.shuffle()
    }

    // Every chunk is the held-out part of one fold; the training part is the union of the others.
    let outIdSet = idx.chunked(into: splits)
    let inIdSet = (0 ..< splits).map { k in
        (0 ..< splits).filter { i in i != k }.flatMap { i in
            outIdSet[i]
        }
    }

    var packs = [CVPack]()

    for k in 0 ..< splits {
        let train = try data.slice(
            indexes: inIdSet[k],
            newName: "\(data.name)-train"
        )
        let test = try data.slice(
            indexes: outIdSet[k],
            newName: "\(data.name)-test"
        )

        packs.append(
            try CVPack(
                id: "\(k)",
                train: train,
                test: test,
                parameters: parameters
            )
        )
    }

    return packs
}

/// Aggregates evaluations of all folds into mean and standard deviation per metric.
///
/// - Parameter results: Array of Evaluation, one per fold.
/// - Returns: Metrics for cross-validation, keyed as "<data name>-<metric name>".
/// - Throws: `ValueError.runtimeError` if a metric value can not be parsed as `Float`.
func aggregateCrossValidationResults(
    results: [Evaluation]
) throws -> [(metricName: String, mean: Float, std: Float)] {
    var crossDict = CVEvaluation()

    for result in results {
        for (dataName, dataResults) in result {
            for (metricName, metricValue) in dataResults {
                // Report unparsable values as an error instead of trapping on a force-unwrap.
                guard let value = Float(metricValue) else {
                    throw ValueError.runtimeError(
                        "Value \(metricValue) of metric \(metricName) for \(dataName) is not a number."
                    )
                }

                crossDict["\(dataName)-\(metricName)", or: []].append(value)
            }
        }
    }

    var aggregated = [(metricName: String, mean: Float, std: Float)]()

    for (metricName, values) in crossDict {
        let mean = values.mean()
        aggregated.append((
            metricName: metricName,
            mean: mean,
            std: values.std(mean: mean)
        ))
    }

    return aggregated
}

/// Typealias for dictionary of cross-validation evaluations.
public typealias CVEvaluation = [String: [Float]]

/// Typealias for function called after CV iteration.
public typealias AfterCVIteration = (Int, CVEvaluation, Bool) throws -> AfterIterationOutput

/// Default function called after CV iteration.
public let DefaultAfterCVIteration: AfterCVIteration = { _, _, _ in .next }

/// Training with cross-validation.
///
/// - Parameter folds: Cross-validation folds to train.
/// - Parameter iterations: Number of training iterations.
/// - Parameter earlyStopping: Early stopping used for cross-validation results.
/// - Parameter objectiveFunction: Custom objective function passed to the booster.train.
/// - Parameter evaluationFunction: Custom evaluation function passed to the booster.train.
/// - Parameter beforeIteration: Before iteration function passed to the booster.train.
/// - Parameter callbacks: Callbacks passed to the booster.train.
/// - Parameter afterIteration: After iteration function passed to the booster.train.
/// - Parameter afterCVIteration: Function called after one cv iteration, e.g. after one training iteration of every fold.
/// - Returns: Aggregated "-mean" and "-std" series per metric and the trained folds.
public func crossValidationTraining(
    folds: [CVPack],
    iterations: Int,
    earlyStopping: EarlyStopping? = nil,
    objectiveFunction: ObjectiveFunction? = nil,
    evaluationFunction: EvaluationFunction? = nil,
    beforeIteration: BeforeIteration = DefaultBeforeIteration,
    callbacks: [Callback] = [],
    afterIteration: AfterIteration = DefaultAfterIteration,
    afterCVIteration: AfterCVIteration = DefaultAfterCVIteration
) throws -> (results: CVEvaluation, folds: [CVPack]) {
    var cvResults = CVEvaluation()

    training: for iteration in 0 ..< iterations {
        var stopped = 0
        var iterationEvaluations = [Evaluation]()

        for fold in folds {
            // Train exactly one iteration of this fold (from `iteration` up to `iteration + 1`).
            try fold.booster.train(
                iterations: iteration + 1,
                startIteration: iteration,
                trainingData: fold.train,
                objectiveFunction: objectiveFunction,
                evaluationData: [fold.train, fold.test],
                evaluationFunction: evaluationFunction,
                beforeIteration: beforeIteration,
                callbacks: callbacks
            ) { booster, iteration, evaluation, outputs in
                if outputs.willStop {
                    stopped += 1
                }

                // Evaluation data is always passed above, so nil here is an error, not a crash.
                guard let evaluation = evaluation else {
                    throw ValueError.runtimeError("Evaluation data can not be nil.")
                }

                iterationEvaluations.append(evaluation)
                return try afterIteration(booster, iteration, evaluation, outputs)
            }
        }

        let aggregatedEvaluation = try aggregateCrossValidationResults(
            results: iterationEvaluations
        )
        for (metricName, mean, std) in aggregatedEvaluation {
            cvResults[metricName + "-mean", or: []].append(mean)
            cvResults[metricName + "-std", or: []].append(std)
        }

        let earlyStop: Bool = try {
            if let earlyStopping = earlyStopping {
                return try earlyStopping.call(
                    iteration: iteration,
                    evaluation: cvResults
                ) == .stop
            }

            return false
        }()

        // Stop when early stopping fired or when every fold asked to stop.
        let willStop = earlyStop || stopped == folds.count

        switch try afterCVIteration(
            iteration,
            cvResults,
            willStop
        ) {
        case .stop:
            break training
        case .next:
            break
        }

        if willStop {
            break training
        }
    }

    if let earlyStopping = earlyStopping {
        // Keep results up to and including the best iteration. `prefix` with a clamped
        // length can not trap when the best iteration lies outside the recorded range.
        let keep = max(0, earlyStopping.state.bestIteration + 1)

        for (key, values) in cvResults {
            cvResults[key] = Array(values.prefix(keep))
        }
    }

    return (cvResults, folds)
}

/// Training with cross-validation.
///
/// - Parameter data: Data to split.
/// - Parameter splits: Number of cross-validation folds that will be created.
/// - Parameter iterations: Number of training iterations.
/// - Parameter parameters: Parameters passed to the Booster initialized in CVPack.
/// - Parameter earlyStopping: Early stopping used for cross-validation results.
/// - Parameter objectiveFunction: Custom objective function passed to the booster.train.
/// - Parameter evaluationFunction: Custom evaluation function passed to the booster.train.
/// - Parameter beforeIteration: Before iteration function passed to the booster.train.
/// - Parameter callbacks: Callbacks passed to the booster.train.
326 | /// - Parameter afterIteration: After iteration function passed to the booster.train. 327 | /// - Parameter afterCVIteration: Function called after one cv iteration, e.g. after one training iteration of every fold. 328 | public func crossValidationTraining( 329 | data: Data, 330 | splits: Int, 331 | iterations: Int, 332 | parameters: [Parameter], 333 | earlyStopping: EarlyStopping? = nil, 334 | objectiveFunction: ObjectiveFunction? = nil, 335 | evaluationFunction: EvaluationFunction? = nil, 336 | shuffle: Bool = true, 337 | beforeIteration: BeforeIteration = DefaultBeforeIteration, 338 | callbacks: [Callback] = [], 339 | afterIteration: AfterIteration = DefaultAfterIteration, 340 | afterCVIteration: AfterCVIteration = DefaultAfterCVIteration 341 | ) throws -> (results: CVEvaluation, folds: [CVPack]) { 342 | try crossValidationTraining( 343 | folds: try makeNFold( 344 | data: data, 345 | splits: splits, 346 | parameters: parameters, 347 | shuffle: shuffle 348 | ), 349 | iterations: iterations, 350 | earlyStopping: earlyStopping, 351 | objectiveFunction: objectiveFunction, 352 | evaluationFunction: evaluationFunction, 353 | beforeIteration: beforeIteration, 354 | callbacks: callbacks, 355 | afterIteration: afterIteration, 356 | afterCVIteration: afterCVIteration 357 | ) 358 | } 359 | -------------------------------------------------------------------------------- /Tests/XGBoostTests/BoosterTests.swift: -------------------------------------------------------------------------------- 1 | import PythonKit 2 | import XCTest 3 | 4 | @testable import XGBoost 5 | 6 | private let PJSON = Python.import("json") 7 | 8 | final class BoosterTests: XCTestCase { 9 | func testAttribute() throws { 10 | let booster = try Booster() 11 | try booster.set(attribute: "testName", value: "testValue") 12 | XCTAssertEqual(try booster.attribute(name: "testName"), "testValue") 13 | XCTAssertNil(try booster.attribute(name: "unknownName")) 14 | } 15 | 16 | func testAttributes() throws { 17 | 
let booster = try Booster() 18 | try booster.set(attribute: "attribute1", value: "value1") 19 | try booster.set(attribute: "attribute2", value: "value2") 20 | XCTAssertEqual(try booster.attributes(), ["attribute1": "value1", "attribute2": "value2"]) 21 | } 22 | 23 | func testJsonDumped() throws { 24 | let randomArray = (0 ..< 10).map { _ in Float.random(in: 0 ..< 2) } 25 | let label = (0 ..< 10).map { _ in Float([0, 1].randomElement()!) } 26 | let data = try DMatrix( 27 | name: "data", 28 | from: randomArray, 29 | shape: Shape(10, 1), 30 | label: label, 31 | threads: 1 32 | ) 33 | let booster = try Booster(with: [data]) 34 | try booster.train( 35 | iterations: 5, 36 | trainingData: data 37 | ) 38 | 39 | let pyBooster = try python(booster: booster) 40 | 41 | let temporaryDumpFile = temporaryFile() 42 | let temporaryFeatureMapFile = temporaryFile() 43 | try booster.features!.saveFeatureMap(to: temporaryFeatureMapFile) 44 | 45 | pyBooster.dump_model(fout: temporaryDumpFile, fmap: temporaryFeatureMapFile, dump_format: "json") 46 | 47 | let json = try booster.dumped(format: .json) 48 | let pyJson = try String(contentsOfFile: temporaryDumpFile) 49 | 50 | let jsonObject = PJSON.loads(json) 51 | let pyJsonObject = PJSON.loads(pyJson) 52 | 53 | XCTAssertTrue(jsonObject == pyJsonObject) 54 | } 55 | 56 | func testTextDumped() throws { 57 | let randomArray = (0 ..< 10).map { _ in Float.random(in: 0 ..< 2) } 58 | let label = (0 ..< 10).map { _ in Float([0, 1].randomElement()!) 
} 59 | let data = try DMatrix( 60 | name: "data", 61 | from: randomArray, 62 | shape: Shape(10, 1), 63 | label: label, 64 | threads: 1 65 | ) 66 | let booster = try Booster(with: [data]) 67 | try booster.train( 68 | iterations: 5, 69 | trainingData: data 70 | ) 71 | 72 | let pyBooster = try python(booster: booster) 73 | 74 | let temporaryDumpFile = temporaryFile() 75 | let temporaryFeatureMapFile = temporaryFile() 76 | 77 | let features = [Feature(name: "x", type: .quantitative)] 78 | try features.saveFeatureMap(to: temporaryFeatureMapFile) 79 | 80 | pyBooster.dump_model( 81 | fout: temporaryDumpFile, fmap: temporaryFeatureMapFile, 82 | with_stats: true, dump_format: "text" 83 | ) 84 | 85 | let textFeatures = try booster.dumped( 86 | features: features, withStatistics: true, format: .text 87 | ) 88 | let textFeatureMap = try booster.dumped( 89 | featureMap: temporaryFeatureMapFile, withStatistics: true, format: .text 90 | ) 91 | let pyText = try String(contentsOfFile: temporaryDumpFile) 92 | 93 | XCTAssertEqual(textFeatures, pyText) 94 | XCTAssertEqual(textFeatureMap, pyText) 95 | } 96 | 97 | func testDotDumped() throws { 98 | let randomArray = (0 ..< 10).map { _ in Float.random(in: 0 ..< 2) } 99 | let label = (0 ..< 10).map { _ in Float([0, 1].randomElement()!) } 100 | let data = try DMatrix( 101 | name: "data", 102 | from: randomArray, 103 | shape: Shape(10, 1), 104 | label: label, 105 | threads: 1 106 | ) 107 | let booster = try Booster(with: [data]) 108 | try booster.train( 109 | iterations: 5, 110 | trainingData: data 111 | ) 112 | 113 | let pyBooster = try python(booster: booster) 114 | 115 | let dot = try booster.rawDumped(format: .dot) 116 | let pyDot = pyBooster.get_dump(dump_format: "dot").map { String($0)! 
} 117 | XCTAssertEqual(dot, pyDot) 118 | 119 | let temporaryFeatureMapFile = temporaryFile() 120 | try [Feature(name: "x", type: .quantitative)].saveFeatureMap(to: temporaryFeatureMapFile) 121 | 122 | let formatDot = try booster.dumped(featureMap: temporaryFeatureMapFile, format: .dot) 123 | let temporaryDumpFile = temporaryFile() 124 | pyBooster.dump_model(temporaryDumpFile, fmap: temporaryFeatureMapFile, dump_format: "dot") 125 | let pyFormatDot = try String(contentsOfFile: temporaryDumpFile) 126 | XCTAssertEqual(formatDot, pyFormatDot) 127 | } 128 | 129 | func testScoreEmptyFeatureMap() throws { 130 | let randomArray = (0 ..< 50).map { _ in Float.random(in: 0 ..< 2) } 131 | let label = (0 ..< 10).map { _ in Float([0, 1].randomElement()!) } 132 | let data = try DMatrix( 133 | name: "data", 134 | from: randomArray, 135 | shape: Shape(10, 5), 136 | label: label, 137 | threads: 1 138 | ) 139 | let booster = try Booster(with: [data]) 140 | try booster.train( 141 | iterations: 10, 142 | trainingData: data 143 | ) 144 | 145 | let pyBooster = try python(booster: booster) 146 | 147 | let pyWeightMap = [String: Int](pyBooster.get_score( 148 | fmap: "", importance_type: "weight" 149 | ))! 150 | let (weightMap, _) = try booster.score(featureMap: "", importance: .weight) 151 | XCTAssertEqual(weightMap, pyWeightMap) 152 | 153 | let pyGainMap = [String: Float](pyBooster.get_score( 154 | fmap: "", importance_type: "gain" 155 | ))! 156 | let (_, gainMap) = try booster.score(featureMap: "", importance: .gain) 157 | assertEqual(gainMap!, pyGainMap, accuracy: 1e-6) 158 | 159 | let pyTotalGainMap = [String: Float](pyBooster.get_score( 160 | fmap: "", importance_type: "total_gain" 161 | ))! 162 | let (_, totalGainMap) = try booster.score(featureMap: "", importance: .totalGain) 163 | assertEqual(totalGainMap!, pyTotalGainMap, accuracy: 1e-6) 164 | 165 | let pyCoverMap = [String: Float](pyBooster.get_score( 166 | fmap: "", importance_type: "cover" 167 | ))! 
168 | let (_, coverMap) = try booster.score(featureMap: "", importance: .cover) 169 | assertEqual(coverMap!, pyCoverMap, accuracy: 1e-6) 170 | 171 | let pyTotalCoverMap = [String: Float](pyBooster.get_score( 172 | fmap: "", importance_type: "total_cover" 173 | ))! 174 | let (_, totalCoverMap) = try booster.score(featureMap: "", importance: .totalCover) 175 | assertEqual(totalCoverMap!, pyTotalCoverMap, accuracy: 1e-6) 176 | } 177 | 178 | func testScoreWithFeatureMap() throws { 179 | let randomArray = (0 ..< 50).map { _ in Float.random(in: 0 ..< 2) } 180 | let features = (0 ..< 5).map { Feature(name: "Feature-\($0)", type: .quantitative) } 181 | let label = (0 ..< 10).map { _ in Float([0, 1].randomElement()!) } 182 | let data = try DMatrix( 183 | name: "data", 184 | from: randomArray, 185 | shape: Shape(10, 5), 186 | features: features, 187 | label: label, 188 | threads: 1 189 | ) 190 | let booster = try Booster(with: [data]) 191 | try booster.train( 192 | iterations: 10, 193 | trainingData: data 194 | ) 195 | 196 | let temporaryNamesFile = temporaryFile() 197 | 198 | try features.saveFeatureMap(to: temporaryNamesFile) 199 | 200 | let pyBooster = try python(booster: booster) 201 | 202 | let pyWeightMap = [String: Int](pyBooster.get_score( 203 | fmap: temporaryNamesFile, importance_type: "weight" 204 | ))! 205 | let (weightMap, _) = try booster.score( 206 | featureMap: temporaryNamesFile, importance: .weight 207 | ) 208 | XCTAssertEqual(weightMap, pyWeightMap) 209 | 210 | let pyGainMap = [String: Float](pyBooster.get_score( 211 | fmap: temporaryNamesFile, importance_type: "gain" 212 | ))! 213 | let (_, gainMap) = try booster.score( 214 | featureMap: temporaryNamesFile, importance: .gain 215 | ) 216 | assertEqual(gainMap!, pyGainMap, accuracy: 1e-6) 217 | 218 | let pyTotalGainMap = [String: Float](pyBooster.get_score( 219 | fmap: temporaryNamesFile, importance_type: "total_gain" 220 | ))! 
221 | let (_, totalGainMap) = try booster.score( 222 | featureMap: temporaryNamesFile, importance: .totalGain 223 | ) 224 | assertEqual(totalGainMap!, pyTotalGainMap, accuracy: 1e-6) 225 | 226 | let pyCoverMap = [String: Float](pyBooster.get_score( 227 | fmap: temporaryNamesFile, importance_type: "cover" 228 | ))! 229 | let (_, coverMap) = try booster.score( 230 | featureMap: temporaryNamesFile, importance: .cover 231 | ) 232 | assertEqual(coverMap!, pyCoverMap, accuracy: 1e-6) 233 | 234 | let pyTotalCoverMap = [String: Float](pyBooster.get_score( 235 | fmap: temporaryNamesFile, importance_type: "total_cover" 236 | ))! 237 | let (_, totalCoverMap) = try booster.score( 238 | featureMap: temporaryNamesFile, importance: .totalCover 239 | ) 240 | assertEqual(totalCoverMap!, pyTotalCoverMap, accuracy: 1e-6) 241 | } 242 | 243 | func testSystemLibraryVersion() throws { 244 | let version = XGBoost.systemLibraryVersion 245 | 246 | XCTAssertGreaterThanOrEqual(version.major, 1) 247 | XCTAssertGreaterThanOrEqual(version.minor, 1) 248 | XCTAssertGreaterThanOrEqual(version.patch, 0) 249 | } 250 | 251 | func testConfig() throws { 252 | let json = Python.import("json") 253 | 254 | let booster = try randomTrainedBooster() 255 | let pyBooster = try python(booster: booster) 256 | 257 | let config = json.loads(try booster.config()) 258 | let pyConfig = json.loads(pyBooster.save_config()) 259 | 260 | config["version"] = Python.None 261 | pyConfig["version"] = Python.None 262 | 263 | XCTAssertEqual(String(json.dumps(config))!, String(json.dumps(pyConfig))!) 264 | } 265 | 266 | func testRawDumped() throws { 267 | let booster = try randomTrainedBooster() 268 | let pyBooster = try python(booster: booster) 269 | 270 | let rawDump = try booster.rawDumped() 271 | let pyRawDump = [String](pyBooster.get_dump())! 
272 | XCTAssertEqual(rawDump, pyRawDump) 273 | 274 | let rawDumpStatistics = try booster.rawDumped(withStatistics: true) 275 | let pyRawDumpStatistics = [String](pyBooster.get_dump(with_stats: true))! 276 | XCTAssertEqual(rawDumpStatistics, pyRawDumpStatistics) 277 | } 278 | 279 | func testLoadModel() throws { 280 | let booster = try randomTrainedBooster() 281 | try booster.set(attribute: "hello", value: "world") 282 | let modelFile = temporaryFile() 283 | 284 | try booster.save(to: modelFile) 285 | let boosterLoaded = try Booster(from: modelFile) 286 | 287 | XCTAssertEqual(try booster.attribute(name: "hello"), try boosterLoaded.attribute(name: "hello")) 288 | } 289 | 290 | func testLoadConfig() throws { 291 | let booster1 = try randomTrainedBooster() 292 | try booster1.set(parameter: "eta", value: "555") 293 | 294 | let modelFile = temporaryFile() 295 | try booster1.save(to: modelFile) 296 | 297 | let booster2 = try randomTrainedBooster() 298 | try booster2.set(parameter: "eta", value: "666") 299 | let config = try booster2.config() 300 | 301 | let booster3 = try Booster(from: modelFile, config: config) 302 | XCTAssertEqual(try booster3.config(), try booster2.config()) 303 | } 304 | 305 | func testInitializeWithType() throws { 306 | XCTAssertEqual(try Booster( 307 | parameters: [Parameter("booster", "dart")] 308 | ).type, BoosterType(rawValue: "dart")) 309 | XCTAssertThrowsError(try Booster( 310 | parameters: [Parameter("booster", "icecream")] 311 | )) 312 | } 313 | 314 | func testBoosterPredictOptionMasks() throws { 315 | let booster = try randomTrainedBooster() 316 | let pyBooster = try python(booster: booster) 317 | 318 | let test = try randomDMatrix() 319 | let pyTest = try python(dmatrix: test) 320 | 321 | let outputMargin = try booster.predict(from: test, outputMargin: true) 322 | let pyoutputMargin = pyBooster.predict(data: pyTest, output_margin: true) 323 | XCTAssertEqual(try outputMargin.data(), try pyoutputMargin.data()) 324 | XCTAssertEqual(try 
outputMargin.dataShape(), try pyoutputMargin.dataShape()) 325 | 326 | let predictionleaf = try booster.predict(from: test, predictionLeaf: true) 327 | let pypredictionleaf = pyBooster.predict(data: pyTest, pred_leaf: true) 328 | XCTAssertEqual(try predictionleaf.data(), try pypredictionleaf.data()) 329 | XCTAssertEqual(try predictionleaf.dataShape(), try pypredictionleaf.dataShape()) 330 | 331 | let predictionContributions = try booster.predict(from: test, predictionContributions: true) 332 | let pypredictionContributions = pyBooster.predict(data: pyTest, pred_contribs: true) 333 | XCTAssertEqual(try predictionContributions.data(), try pypredictionContributions.data()) 334 | XCTAssertEqual(try predictionContributions.dataShape(), try pypredictionContributions.dataShape()) 335 | 336 | let approximateContributions = try booster.predict(from: test, approximateContributions: true) 337 | let pyapproximateContributions = pyBooster.predict(data: pyTest, approx_contribs: true) 338 | XCTAssertEqual(try approximateContributions.data(), try pyapproximateContributions.data()) 339 | XCTAssertEqual(try approximateContributions.dataShape(), try pyapproximateContributions.dataShape()) 340 | 341 | let predictionInteractions = try booster.predict(from: test, predictionInteractions: true) 342 | let pypredictionInteractions = pyBooster.predict(data: pyTest, pred_interactions: true) 343 | XCTAssertEqual(try predictionInteractions.data(), try pypredictionInteractions.data()) 344 | XCTAssertEqual(try predictionInteractions.dataShape(), try pypredictionInteractions.dataShape()) 345 | } 346 | 347 | func testPredictOne() throws { 348 | let booster = try randomTrainedBooster() 349 | let pyBooster = try python(booster: booster) 350 | 351 | let features: [Float] = [1, 2, 3, 4, 5] 352 | let dmatrix = try DMatrix(name: "test", from: features, shape: Shape(1, 5)) 353 | 354 | let pydmatrix = try python(dmatrix: dmatrix) 355 | 356 | let predicted = try booster.predict(features: features) 357 | let 
pypredicted = pyBooster.predict(pydmatrix) 358 | 359 | XCTAssertEqual(predicted, Float(pypredicted[0])!) 360 | } 361 | 362 | static var allTests = [ 363 | ("testAttribute", testAttribute), 364 | ("testAttributes", testAttributes), 365 | ("testJsonDumped", testJsonDumped), 366 | ("testTextDumped", testTextDumped), 367 | ("testDotDumped", testDotDumped), 368 | ("testScoreEmptyFeatureMap", testScoreEmptyFeatureMap), 369 | ("testScoreWithFeatureMap", testScoreWithFeatureMap), 370 | ("testSystemLibraryVersion", testSystemLibraryVersion), 371 | ("testConfig", testConfig), 372 | ("testRawDumped", testRawDumped), 373 | ("testLoadModel", testLoadModel), 374 | ("testLoadConfig", testLoadConfig), 375 | ("testInitializeWithType", testInitializeWithType), 376 | ("testBoosterPredictOptionMasks", testBoosterPredictOptionMasks), 377 | ("testPredictOne", testPredictOne), 378 | ] 379 | } 380 | -------------------------------------------------------------------------------- /Sources/XGBoost/Train.swift: -------------------------------------------------------------------------------- 1 | import CXGBoost 2 | import Foundation 3 | 4 | /// Dictionary for evaluation in form [data_name: [metric_name: value]] 5 | public typealias Evaluation = [String: [String: String]] 6 | 7 | /// Specify if should be executed at beginning of iteration or at the end. 8 | public enum LoopPosition { 9 | case before, after 10 | } 11 | 12 | /// Protocol for classes and structs that can be passed to traning in callbacks array. 13 | public protocol Callback { 14 | // Name of the callback 15 | var name: String { get } 16 | 17 | /// Time when to execute the callback in training loop. 18 | var execute: [LoopPosition] { get } 19 | 20 | /// - Parameter booster: Optional booster. 21 | /// - Parameter iteration: Current iteration. 22 | /// - Parameter evaluation: Optional dictionary with evaluations. 23 | func call( 24 | booster: Booster?, 25 | iteration: Int, 26 | evaluation: Evaluation? 
27 | ) throws -> AfterIterationOutput 28 | } 29 | 30 | /// Class used for early stopping feature. 31 | public class EarlyStopping: Callback { 32 | /// Name of the callback. 33 | public var name: String = "EarlyStopping" 34 | 35 | /// Position of the execution. 36 | public var execute: [LoopPosition] = [.after] 37 | 38 | /// Typealais for tuple holding current state. 39 | public typealias State = ( 40 | maximizeScore: Bool, 41 | bestIteration: Int, 42 | bestScore: Float, 43 | bestMsg: String 44 | ) 45 | 46 | /// Current state of training. 47 | public var state: State 48 | 49 | /// Name of watched DMatrix. 50 | public var dataName: String 51 | 52 | /// Name of watched metric. 53 | public var metricName: String 54 | 55 | /// Number of stopping rounds. 56 | public var stoppingRounds: Int 57 | 58 | /// If true, statistics will be printed. 59 | public var verbose: Bool 60 | 61 | /// - Parameter bestIteration: Number of booster best iteration. 62 | /// - Parameter bestEvaluation: Booster best evaluation. 63 | static func formatMessage( 64 | bestIteration: Int, 65 | bestEvaluation: Any 66 | ) -> String { 67 | "Best iteration: \(bestIteration), best evaluation: \(bestEvaluation)." 68 | } 69 | 70 | /// Decides whenever metric should be maximized or minimized. 71 | /// 72 | /// - Parameter maximize: Force maximization. 73 | /// - Parameter metricName: Metric name to check, if it should be maximized. 74 | static func shouldMaximize( 75 | maximize: Bool, 76 | metricName: String 77 | ) -> Bool { 78 | var maximizeScore = maximize 79 | let maximizeMetrics = ["auc", "aucpr", "map", "ndcg"] 80 | let maximizeAtNMetrics = ["auc@", "aucpr@", "map@", "ndcg@"] 81 | 82 | let metric = metricName.split(separator: "-", maxSplits: 1).last! 83 | 84 | if maximizeAtNMetrics.contains(where: { metric.hasPrefix($0) }) { 85 | maximizeScore = true 86 | } 87 | 88 | if maximizeMetrics.contains(where: { $0 == metric.split(separator: ":").first! 
}) { 89 | maximizeScore = true 90 | } 91 | 92 | return maximizeScore 93 | } 94 | 95 | /// - Parameter dataName: Name of data used for early stopping. 96 | /// - Parameter metricName: Metric to look for. 97 | /// - Parameter stoppingRounds: Number of rounds to check improvence for. 98 | /// - Parameter state: Initial state. 99 | /// - Parameter verbose: Print on new best or stopping. 100 | public required init( 101 | dataName: String, 102 | metricName: String, 103 | stoppingRounds: Int, 104 | state: State, 105 | verbose: Bool = false 106 | ) { 107 | self.dataName = dataName 108 | self.metricName = metricName 109 | self.stoppingRounds = stoppingRounds 110 | self.verbose = verbose 111 | self.state = state 112 | } 113 | 114 | /// - Parameter dataName: Name of data used for early stopping. 115 | /// - Parameter metricName: Metric to look for. 116 | /// - Parameter stoppingRounds: Number of rounds to check improvence for. 117 | /// - Parameter maximize: If metric should be maximized, minimzed otherwise. 118 | /// - Parameter verbose: Print on new best or stopping. 119 | public convenience init( 120 | dataName: String, 121 | metricName: String, 122 | stoppingRounds: Int, 123 | maximize: Bool = false, 124 | verbose: Bool = false 125 | ) { 126 | let maximizeScore = EarlyStopping.shouldMaximize(maximize: maximize, metricName: metricName) 127 | self.init( 128 | dataName: dataName, 129 | metricName: metricName, 130 | stoppingRounds: stoppingRounds, 131 | state: ( 132 | maximizeScore: maximizeScore, 133 | bestIteration: 0, 134 | bestScore: maximizeScore ? -Float.greatestFiniteMagnitude : Float.greatestFiniteMagnitude, 135 | bestMsg: EarlyStopping.formatMessage(bestIteration: 0, bestEvaluation: [:]) 136 | ), 137 | verbose: verbose 138 | ) 139 | } 140 | 141 | /// - Parameter dataName: Name of data used for early stopping. 142 | /// - Parameter metricName: Metric to look for. 143 | /// - Parameter stoppingRounds: Number of rounds to check improvence for. 
144 | /// - Parameter maximize: If metric should be maximized, minimzed otherwise. 145 | /// - Parameter booster: Booster to load state from. 146 | /// - Parameter verbose: Print on new best or stopping. 147 | public convenience init( 148 | dataName: String, 149 | metricName: String, 150 | stoppingRounds: Int, 151 | maximize: Bool, 152 | booster: Booster, 153 | verbose: Bool = false 154 | ) throws { 155 | let maximizeScore = EarlyStopping.shouldMaximize(maximize: maximize, metricName: metricName) 156 | 157 | var state = ( 158 | maximizeScore: maximizeScore, 159 | bestIteration: 0, 160 | bestScore: maximizeScore ? -Float.greatestFiniteMagnitude : Float.greatestFiniteMagnitude, 161 | bestMsg: EarlyStopping.formatMessage(bestIteration: 0, bestEvaluation: [:]) 162 | ) 163 | 164 | if let bestIteration = try booster.attribute(name: "best_iteration") { 165 | state.bestIteration = Int(bestIteration)! 166 | } 167 | 168 | if let bestScore = try booster.attribute(name: "best_score") { 169 | state.bestScore = Float(bestScore)! 170 | } 171 | 172 | if let bestMsg = try booster.attribute(name: "best_msg") { 173 | state.bestMsg = bestMsg 174 | } 175 | 176 | self.init( 177 | dataName: dataName, 178 | metricName: metricName, 179 | stoppingRounds: stoppingRounds, 180 | state: state, 181 | verbose: verbose 182 | ) 183 | } 184 | 185 | /// Call used in Booster training. 186 | /// 187 | /// - Parameter booster: Booster. 188 | /// - Parameter iteration: Current iteration. 189 | /// - Parameter evaluation: Dictionary with evaluations. 190 | public func call( 191 | booster: Booster? = nil, 192 | iteration: Int, 193 | evaluation: Evaluation? 
194 | ) throws -> AfterIterationOutput { 195 | guard let evaluation = evaluation else { 196 | throw ValueError.runtimeError("Evaluation data can not be nil.") 197 | } 198 | 199 | guard let data = evaluation[dataName] else { 200 | throw ValueError.runtimeError("Name of DMatrix \(dataName) not found in evaluation.") 201 | } 202 | 203 | guard let stringScore = data[metricName] else { 204 | throw ValueError.runtimeError("Name of metric \(metricName) not found in evaluation.") 205 | } 206 | 207 | let score = Float(stringScore)! 208 | 209 | if (state.maximizeScore && score > state.bestScore) || (!state.maximizeScore && score < state.bestScore) { 210 | state.bestScore = score 211 | state.bestIteration = iteration 212 | state.bestMsg = EarlyStopping.formatMessage(bestIteration: iteration, bestEvaluation: evaluation) 213 | 214 | if let booster = booster { 215 | try booster.set(attribute: "best_score", value: String(state.bestScore)) 216 | try booster.set(attribute: "best_iteration", value: String(state.bestIteration)) 217 | try booster.set(attribute: "best_msg", value: state.bestMsg) 218 | } 219 | 220 | if verbose { 221 | log(state.bestMsg) 222 | } 223 | } else if iteration - state.bestIteration >= stoppingRounds { 224 | if verbose { 225 | log("Stopping at iteration \(iteration): " + state.bestMsg) 226 | } 227 | 228 | return .stop 229 | } 230 | 231 | return .next 232 | } 233 | 234 | /// Call used in cross-validation training. 235 | /// 236 | /// - Parameter iteration: Current iteration. 237 | /// - Parameter evaluation: Cross-validation evaluation. 
238 | public func call( 239 | iteration: Int, 240 | evaluation: CVEvaluation 241 | ) throws -> AfterIterationOutput { 242 | let name = dataName + "-test-" + metricName + "-mean" 243 | 244 | guard let scores = evaluation[name] else { 245 | throw ValueError.runtimeError("Name \(name) not found in evaluation \(evaluation).") 246 | } 247 | 248 | guard let score = scores.last else { 249 | throw ValueError.runtimeError("Scores for \(name) are empty.") 250 | } 251 | 252 | if (state.maximizeScore && score > state.bestScore) || (!state.maximizeScore && score < state.bestScore) { 253 | state.bestScore = score 254 | state.bestIteration = iteration 255 | state.bestMsg = EarlyStopping.formatMessage(bestIteration: iteration, bestEvaluation: evaluation) 256 | 257 | if verbose { 258 | log(state.bestMsg) 259 | } 260 | } else if iteration - state.bestIteration >= stoppingRounds { 261 | if verbose { 262 | log("Stopping at iteration \(iteration): " + state.bestMsg) 263 | } 264 | 265 | return .stop 266 | } 267 | 268 | return .next 269 | } 270 | } 271 | 272 | /// Class used for variable learning rate feature. 273 | public class VariableLearningRate: Callback { 274 | /// Name of the callback. 275 | public var name: String = "VariableLearningRate" 276 | 277 | /// Position of the execution. 278 | public var execute: [LoopPosition] = [.before] 279 | 280 | /// Typealias for learningRateFunction. 281 | public typealias Function = (Int, Int) -> String 282 | 283 | /// Number of iterations in training. 284 | var iterations: Int 285 | 286 | /// List with learning rates. 287 | var learningRates: [String]? 288 | 289 | /// Function returning learning rate as string based on current iteration and maximum number of iterations. 290 | var learningRateFunction: Function? 291 | 292 | /// Initialize VariableLearningRate by array of learning rates. 293 | /// 294 | /// - Precondition: iterations == learningRates.count. 
295 | /// - Parameter learningRates: Array of learning rates that will be accessed at every iteration. 296 | /// - Parameter iterations: Number of iteration in training. 297 | public init( 298 | learningRates: [String], 299 | iterations: Int 300 | ) { 301 | precondition( 302 | iterations == learningRates.count, 303 | "Learning rates count must be equal to iterations." 304 | ) 305 | 306 | self.learningRates = learningRates 307 | self.iterations = iterations 308 | } 309 | 310 | /// Initialize VariableLearningRate with function generating learning rate. 311 | /// 312 | /// - Parameter learningRate: Function that will return learning rate at each iteration. 313 | /// - Parameter iterations: Number of iteration in training. 314 | public init( 315 | learningRate: @escaping Function, 316 | iterations: Int 317 | ) { 318 | learningRateFunction = learningRate 319 | self.iterations = iterations 320 | } 321 | 322 | /// - Parameter booster: Booster. 323 | /// - Parameter iteration: Current iteration. 324 | /// - Parameter evaluation: Dictionary with evaluations. 325 | public func call( 326 | booster: Booster?, 327 | iteration: Int, 328 | evaluation _: Evaluation? 329 | ) throws -> AfterIterationOutput { 330 | let newLearningRate: String = { 331 | if let rates = learningRates { 332 | return rates[iteration] 333 | } else { 334 | return learningRateFunction!(iteration, iterations) 335 | } 336 | }() 337 | 338 | guard let booster = booster else { 339 | throw ValueError.runtimeError("Booster is required to set learning rate.") 340 | } 341 | 342 | try booster.set(parameter: "learning_rate", value: newLearningRate) 343 | 344 | return .next 345 | } 346 | } 347 | 348 | /// Typealias for function called before each iteration at training. 349 | public typealias BeforeIteration = (Booster, Int) throws -> AfterIterationOutput 350 | 351 | /// Typealias for function called after each iteration at training. 
extension Booster {
    /// Train booster.
    ///
    /// Each iteration runs, in order: `beforeIteration`, callbacks registered for `.before`,
    /// one update step, optional evaluation, callbacks registered for `.after`,
    /// a Rabit checkpoint save and finally `afterIteration`.
    ///
    /// - Parameter iterations: Number of training iterations, but training can be stopped early.
    /// - Parameter startIteration: N. of starting iteration, defaults to half of the loaded Rabit checkpoint version.
    /// - Parameter trainingData: Data to train on.
    /// - Parameter objectiveFunction: Custom objective function, used for the update step if provided.
    /// - Parameter evaluationData: Data to evaluate on, if provided.
    /// - Parameter evaluationFunction: Custom evaluation function.
    /// - Parameter beforeIteration: Callback called before each iteration.
    /// - Parameter callbacks: Array of callbacks called at each iteration.
    /// - Parameter afterIteration: Callback called after each iteration.
    public func train(
        iterations: Int,
        startIteration: Int? = nil,
        trainingData: Data,
        objectiveFunction: ObjectiveFunction? = nil,
        evaluationData: [Data] = [],
        evaluationFunction: EvaluationFunction? = nil,
        beforeIteration: BeforeIteration = DefaultBeforeIteration,
        callbacks: [Callback] = [],
        afterIteration: AfterIteration = DefaultAfterIteration
    ) throws {
        var version = try loadRabitCheckpoint()
        let firstIteration = startIteration ?? Int(version / 2)

        for iteration in firstIteration ..< iterations {
            var outputs = [AfterIterationOutput]()

            if try beforeIteration(self, iteration) == .stop {
                break
            }

            let beforeCallbacks = callbacks.filter { $0.execute.contains(.before) }
            for callback in beforeCallbacks {
                outputs.append(try callback.call(booster: self, iteration: iteration, evaluation: nil))
            }

            // A stop requested before the update skips the update for this iteration entirely.
            if outputs.contains(where: { $0 == .stop }) {
                break
            }

            if let objective = objectiveFunction {
                try update(data: trainingData, objective: objective)
            } else {
                try update(iteration: iteration, data: trainingData)
            }

            let evaluation =
                evaluationData.isEmpty ? nil : try evaluate(
                    iteration: iteration, data: evaluationData, function: evaluationFunction
                )

            let afterCallbacks = callbacks.filter { $0.execute.contains(.after) }
            for callback in afterCallbacks {
                outputs.append(try callback.call(booster: self, iteration: iteration, evaluation: evaluation))
            }

            try saveRabitCheckpoint()
            version += 1

            // afterIteration always runs, even when a callback already asked to stop;
            // it receives every callback output of this iteration.
            let verdict = try afterIteration(self, iteration, evaluation, outputs)

            if verdict == .stop || outputs.contains(where: { $0 == .stop }) {
                break
            }
        }
    }
}
/// Shape of data structure, stored as an array with one size per dimension.
public typealias Shape = [Int]

public extension Shape {
    /// Number of dimensions, i.e. the number of entries in the shape.
    var rank: Int {
        return self.count
    }

    /// Create a shape from a list of dimension sizes, e.g. `Shape(2, 3)`.
    ///
    /// - Parameter elements: Sizes of the individual dimensions.
    init(_ elements: Int...) {
        self = elements
    }
}
/// Initialize Data with an existing DMatrixHandle pointer.
///
/// Features are applied first, then the float fields in the order
/// label, weight, baseMargin, labelLowerBound, labelUpperBound. Nil arguments are skipped.
///
/// - Parameter name: Name of dataset.
/// - Parameter dmatrix: DMatrixHandle pointer.
/// - Parameter features: Array describing features in this dataset.
/// - Parameter label: Sets label after DMatrix initialization.
/// - Parameter weight: Sets weight after DMatrix initialization.
/// - Parameter baseMargin: Sets baseMargin after DMatrix initialization.
/// - Parameter labelLowerBound: Sets labelLowerBound after DMatrix initialization.
/// - Parameter labelUpperBound: Sets labelUpperBound after DMatrix initialization.
public init(
    name: String,
    dmatrix: DMatrixHandle?,
    features: [Feature]?,
    label: FloatData?,
    weight: FloatData?,
    baseMargin: FloatData?,
    labelLowerBound: FloatData?,
    labelUpperBound: FloatData?
) throws {
    self.name = name
    self.dmatrix = dmatrix

    if let initialFeatures = features {
        try set(features: initialFeatures)
    }

    let floatFields: [(FloatField, FloatData?)] = [
        (.label, label),
        (.weight, weight),
        (.baseMargin, baseMargin),
        (.labelLowerBound, labelLowerBound),
        (.labelUpperBound, labelUpperBound),
    ]

    // Only the fields that were actually provided are written to the handle.
    for case let (field, values?) in floatFields {
        try set(field: field, values: values)
    }
}
/// Initialize Data.
///
/// - Parameter name: Name of dataset.
/// - Parameter from: Values source.
/// - Parameter shape: Shape of resulting DMatrixHandle, needs to be of rank two, e.g. [row, column].
/// - Parameter features: Names and types of features.
/// - Parameter label: Sets label after DMatrix initialization.
/// - Parameter weight: Sets weight after DMatrix initialization.
/// - Parameter baseMargin: Sets baseMargin after DMatrix initialization.
/// - Parameter labelLowerBound: Sets labelLowerBound after DMatrix initialization.
/// - Parameter labelUpperBound: Sets labelUpperBound after DMatrix initialization.
/// - Parameter missingValue: Value in the input data which needs to be present as a missing value.
/// - Parameter threads: Number of threads to use for loading data when parallelization is applicable. If 0, uses maximum threads available on the system.
/// - Throws: `ValueError.runtimeError` when the shape is not of rank two, has a negative dimension,
///   or describes more values than `from` provides.
public convenience init(
    name: String,
    from data: FloatData,
    shape: Shape,
    features: [Feature]? = nil,
    label: FloatData? = nil,
    weight: FloatData? = nil,
    baseMargin: FloatData? = nil,
    labelLowerBound: FloatData? = nil,
    labelUpperBound: FloatData? = nil,
    missingValue: Float = Float.greatestFiniteMagnitude,
    threads: Int = 0
) throws {
    guard shape.rank == 2 else {
        throw ValueError.runtimeError("Invalid shape \(shape), required rank 2.")
    }

    let rows = shape[0]
    let columns = shape[1]

    // UInt64(_:) traps on negative input, so reject such shapes with a regular error.
    guard rows >= 0, columns >= 0 else {
        throw ValueError.runtimeError("Invalid shape \(shape), dimensions must not be negative.")
    }

    let values = try data.data()
    let (required, overflow) = rows.multipliedReportingOverflow(by: columns)

    // The pointer to `values` is handed to C together with rows * columns;
    // fewer values than that would let the library read past the end of the buffer.
    guard !overflow, values.count >= required else {
        throw ValueError.runtimeError(
            "Data contains \(values.count) values, but shape \(shape) requires rows * columns of them."
        )
    }

    var dmatrix: DMatrixHandle?
    try safe {
        XGDMatrixCreateFromMat_omp(
            values,
            UInt64(rows),
            UInt64(columns),
            missingValue,
            &dmatrix,
            Int32(threads)
        )
    }

    try self.init(
        name: name,
        dmatrix: dmatrix,
        features: features,
        label: label,
        weight: weight,
        baseMargin: baseMargin,
        labelLowerBound: labelLowerBound,
        labelUpperBound: labelUpperBound
    )
}
/// Initialize Data from a source that also knows its own shape.
///
/// Forwards to the shape-based initializer using `from.dataShape()` as the shape.
///
/// - Parameter name: Name of dataset.
/// - Parameter from: Data with shape conformance.
/// - Parameter features: Names and types of features.
/// - Parameter label: Sets label after DMatrix initialization.
/// - Parameter weight: Sets weight after DMatrix initialization.
/// - Parameter baseMargin: Sets baseMargin after DMatrix initialization.
/// - Parameter labelLowerBound: Sets labelLowerBound after DMatrix initialization.
/// - Parameter labelUpperBound: Sets labelUpperBound after DMatrix initialization.
/// - Parameter missingValue: Value in the input data which needs to be present as a missing value.
/// - Parameter threads: Number of threads to use for loading data when parallelization is applicable. If 0, uses maximum threads available on the system.
public convenience init(
    name: String,
    from: FloatData & ShapeData,
    features: [Feature]? = nil,
    label: FloatData? = nil,
    weight: FloatData? = nil,
    baseMargin: FloatData? = nil,
    labelLowerBound: FloatData? = nil,
    labelUpperBound: FloatData? = nil,
    missingValue: Float = Float.greatestFiniteMagnitude,
    threads: Int = 0
) throws {
    try self.init(
        name: name,
        from: from,
        shape: try from.dataShape(),
        features: features,
        label: label,
        weight: weight,
        baseMargin: baseMargin,
        labelLowerBound: labelLowerBound,
        labelUpperBound: labelUpperBound,
        // Previously the argument was accepted but dropped, so a caller-supplied
        // missing value was silently replaced by the default.
        missingValue: missingValue,
        threads: threads
    )
}
/// Write the underlying DMatrixHandle to a binary file.
///
/// - Parameter to: File path.
/// - Parameter silent: Forwarded to XGDMatrixSaveBinary as 1 when true and 0 when false.
public func save(
    to path: String,
    silent: Bool = true
) throws {
    let silentFlag: Int32 = silent ? 1 : 0

    try safe {
        XGDMatrixSaveBinary(dmatrix, path, silentFlag)
    }
}
/// - Returns: The number of rows.
/// - Throws: Whatever `safe` throws when XGDMatrixNumRow reports a failure.
public func rowCount() throws -> Int {
    var count: UInt64 = 0

    // The function is already declared `throws`, so propagate the error
    // instead of crashing the process with `try!`.
    try safe {
        XGDMatrixNumRow(
            dmatrix,
            &count
        )
    }

    return Int(count)
}
/// Slice and return a new Data that only contains indexes from the given range.
///
/// The range is expanded into an array and forwarded to the array-based `slice`.
///
/// - Parameter indexes: Range of indices to be selected.
/// - Parameter newName: New name of the returned Data.
/// - Parameter allowGroups: Allow slicing of a matrix with a groups attribute.
/// - Returns: A new DMatrix class containing only selected indexes.
public func slice(
    indexes: Range<Int32>,
    newName: String? = nil,
    allowGroups: Bool = false
) throws -> DMatrix {
    // `Range` needs its bound type spelled out; Int32 matches the Int32Data
    // expected by the array-based overload this call forwards to.
    try slice(
        indexes: Array(indexes),
        newName: newName,
        allowGroups: allowGroups
    )
}
/// Get unsigned integer property from the DMatrixHandle.
///
/// - Parameter field: The field name of the information.
/// - Returns: An array of unsigned integer information of the data; empty when the
///   library returns no buffer.
public func getUIntInfo(
    field: String
) throws -> [UInt32] {
    // Plain locals passed by address replace the two heap cells that were
    // allocated with `allocate(capacity:)` and never deallocated.
    var length: UInt64 = 0
    var result: UnsafePointer<UInt32>?

    try safe {
        XGDMatrixGetUIntInfo(dmatrix, field, &length, &result)
    }

    guard let values = result else {
        return []
    }

    // Copy the values out right away so the returned array does not alias the library buffer.
    return Array(UnsafeBufferPointer(start: values, count: Int(length)))
}
/// Save names and types of features.
///
/// Passing nil clears the stored features. Otherwise the array must have one entry
/// per column, and names must be unique and free of `[`, `]` and `<`.
///
/// - Parameter features: Optional array of features.
/// - Throws: `ValueError.runtimeError` when any of the checks above fails.
public func set(
    features: [Feature]?
) throws {
    guard let newFeatures = features else {
        _features = nil
        return
    }

    let expectedCount = try columnCount()

    guard newFeatures.count == expectedCount else {
        throw ValueError.runtimeError("Features count \(newFeatures.count) != data count \(expectedCount).")
    }

    let names = newFeatures.map { $0.name }

    guard Set(names).count == names.count else {
        throw ValueError.runtimeError("Feature names must be unique.")
    }

    let hasForbiddenName = names.contains { $0.contains("[") || $0.contains("]") || $0.contains("<") }

    guard !hasForbiddenName else {
        throw ValueError.runtimeError("Feature names must not contain [, ] or <.")
    }

    _features = newFeatures
}
/// Register callback function for LOG(INFO) messages.
///
/// - Parameter call: Function to be called with C-String as parameter. Use String(cString: $0!) to convert it into Swift string.
public static func registerLogCallback(
    _ call: (@convention(c) (UnsafePointer<CChar>?) -> Void)?
) throws {
    // The pointee type has to be spelled out; the callback receives a C string.
    try safe {
        XGBRegisterLogCallback(call)
    }
}
/// Initialize new Booster.
///
/// After the handle is created, the steps run in this order: load model from `path`,
/// load `config`, set `validate_parameters`, set each parameter, validate `data`.
///
/// - Parameter with: Data that will be cached.
/// - Parameter from: Loads model from path.
/// - Parameter config: Loads model from config.
/// - Parameter parameters: Array of parameters to be set.
/// - Parameter validateParameters: If true, parameters will be validated. This basically adds parameter validate_parameters=1.
/// - Throws: `ValueError.runtimeError` when a `booster` parameter holds an unknown booster type.
public convenience init(
    with data: [Data] = [],
    from path: String? = nil,
    config: String? = nil,
    parameters: [Parameter] = [],
    validateParameters: Bool = true
) throws {
    let handles = data.map { $0.dmatrix }
    var handle: BoosterHandle?

    try safe {
        XGBoosterCreate(handles, UInt64(handles.count), &handle)
    }

    self.init(booster: handle)

    if let modelPath = path {
        try load(model: modelPath)
    }

    if let configuration = config {
        try load(config: configuration)
    }

    if validateParameters {
        try set(parameter: "validate_parameters", value: "1")
    }

    for parameter in parameters {
        try set(parameter: parameter.name, value: parameter.value)

        // Only the "booster" parameter carries the booster type that is remembered on self.
        guard parameter.name == "booster" else {
            continue
        }

        guard let boosterType = BoosterType(rawValue: parameter.value) else {
            throw ValueError.runtimeError("Unknown booster type \(parameter.value).")
        }

        type = boosterType
    }

    try validate(data: data)
}
/// - Returns: Booster's internal configuration in a JSON string.
/// - Throws: `ValueError.runtimeError` when the library returns no string,
///   or whatever `safe` throws when the call itself fails.
public func config() throws -> String {
    // Plain locals passed by address replace the two heap cells that were
    // allocated with `allocate(capacity:)` and never deallocated.
    var length: UInt64 = 0
    var json: UnsafePointer<CChar>?

    try safe {
        XGBoosterSaveJsonConfig(booster, &length, &json)
    }

    guard let configuration = json else {
        throw ValueError.runtimeError("Booster returned no configuration.")
    }

    return String(cString: configuration)
}
/// Predict from data.
///
/// The result shape is `[rows]` by default. When the library returns a whole multiple
/// of the row count, the shape is extended: `[rows, chunk]` in general, and shapes built
/// from `columns + 1` (with an extra group axis when more than one group is present)
/// for contributions and interactions.
///
/// - Parameter from: Data to predict from.
/// - Parameter outputMargin: Whether to output the raw untransformed margin value.
/// - Parameter treeLimit: Limit number of trees in the prediction. Zero means use all trees.
/// - Parameter predictionLeaf: Each record indicating the predicted leaf index of each sample in each tree.
/// - Parameter predictionContributions: Each record indicating the feature contributions (SHAP values) for that prediction.
/// - Parameter approximateContributions: Approximate the contributions of each feature.
/// - Parameter predictionInteractions: Indicate the SHAP interaction values for each pair of features.
/// - Parameter training: Whether the prediction will be used for training.
/// - Parameter validateFeatures: Validate booster and data features.
/// - Returns: Predictions together with their shape.
public func predict(
    from data: DMatrix,
    outputMargin: Bool = false,
    treeLimit: UInt32 = 0,
    predictionLeaf: Bool = false,
    predictionContributions: Bool = false,
    approximateContributions: Bool = false,
    predictionInteractions: Bool = false,
    training: Bool = false,
    validateFeatures: Bool = true
) throws -> ArrayWithShape {
    if validateFeatures {
        try validate(data: data)
    }

    // Each enabled option contributes one bit to the mask passed to XGBoosterPredict.
    let options: [(enabled: Bool, bit: Int32)] = [
        (outputMargin, 0x01),
        (predictionLeaf, 0x02),
        (predictionContributions, 0x04),
        (approximateContributions, 0x08),
        (predictionInteractions, 0x10),
    ]
    let optionMask = options.reduce(Int32(0)) { mask, option in
        option.enabled ? mask | option.bit : mask
    }

    // Plain locals passed by address replace the two heap cells that were
    // allocated with `allocate(capacity:)` and never deallocated.
    var length: UInt64 = 0
    var result: UnsafePointer<Float>?

    try safe {
        XGBoosterPredict(
            booster,
            data.dmatrix,
            optionMask,
            treeLimit,
            training ? 1 : 0,
            &length,
            &result
        )
    }

    // Copy the values out right away so the returned array does not alias the library buffer.
    let predictions: [Float]
    if let values = result {
        predictions = Array(UnsafeBufferPointer(start: values, count: Int(length)))
    } else {
        predictions = []
    }

    let rowCount = try data.rowCount()
    let columnCount = try data.columnCount()
    var shape = Shape(rowCount)

    // `rowCount > 0` keeps the modulo and the division below from trapping on an empty matrix.
    if rowCount > 0, predictions.count != rowCount, predictions.count % rowCount == 0 {
        let chunkSize = predictions.count / rowCount

        if predictionInteractions {
            let nGroup = chunkSize / ((columnCount + 1) * (columnCount + 1))

            if nGroup == 1 {
                shape = Shape(rowCount, columnCount + 1, columnCount + 1)
            } else {
                shape = Shape(rowCount, nGroup, columnCount + 1, columnCount + 1)
            }
        } else if predictionContributions {
            let nGroup = chunkSize / (columnCount + 1)

            if nGroup == 1 {
                shape = Shape(rowCount, columnCount + 1)
            } else {
                shape = Shape(rowCount, nGroup, columnCount + 1)
            }
        } else {
            shape = Shape(rowCount, chunkSize)
        }
    }

    return ArrayWithShape(predictions, shape: shape)
}
/// Predict directly from array of floats.
///
/// Wraps the values into a DMatrix named "predict" with one row and `features.count`
/// columns, runs the DMatrix-based `predict` and returns the element at index 0.
///
/// - Parameter features: Features to base prediction at.
/// - Parameter outputMargin: Whether to output the raw untransformed margin value.
/// - Parameter treeLimit: Limit number of trees in the prediction. Zero means use all trees.
/// - Parameter predictionLeaf: Each record indicating the predicted leaf index of each sample in each tree.
/// - Parameter predictionContributions: Each record indicating the feature contributions (SHAP values) for that prediction.
/// - Parameter approximateContributions: Approximate the contributions of each feature.
/// - Parameter predictionInteractions: Indicate the SHAP interaction values for each pair of features.
/// - Parameter training: Whether the prediction will be used for training.
/// - Parameter missingValue: Value in features representing missing values.
public func predict(
    features: FloatData,
    outputMargin: Bool = false,
    treeLimit: UInt32 = 0,
    predictionLeaf: Bool = false,
    predictionContributions: Bool = false,
    approximateContributions: Bool = false,
    predictionInteractions: Bool = false,
    training: Bool = false,
    missingValue: Float = Float.greatestFiniteMagnitude
) throws -> Float {
    let values = try features.data()

    // `self.features` are the booster's own features; the parameter of the same name holds the input values.
    let singleRow = try DMatrix(
        name: "predict",
        from: values,
        shape: Shape(1, values.count),
        features: self.features,
        missingValue: missingValue
    )

    let prediction = try predict(
        from: singleRow,
        outputMargin: outputMargin,
        treeLimit: treeLimit,
        predictionLeaf: predictionLeaf,
        predictionContributions: predictionContributions,
        approximateContributions: approximateContributions,
        predictionInteractions: predictionInteractions,
        training: training
    )

    return prediction[0]
}
/// Join per-model dump strings into a single output string.
///
/// - Parameter models: Dumped models, one string per model.
/// - Parameter format: Format the models were dumped in.
/// - Returns: For `.json`, the models joined with ",\n" and wrapped in "[\n" and "]\n";
///   for `.text` and `.dot`, each model prefixed with a "booster[index]:" line and concatenated.
func formatModelDump(
    models: [String],
    format: ModelFormat
) -> String {
    switch format {
    case .json:
        return "[\n" + models.joined(separator: ",\n") + "]\n"
    case .text, .dot:
        return models
            .enumerated()
            .map { "booster[\($0.offset)]:\n\($0.element)" }
            .joined()
    }
}
/// - Parameter withStatistics: Controls whether the split statistics are output.
/// - Parameter format: Desired output format type.
/// - Returns: Raw output from XGBoosterDumpModelEx provided as array of strings.
/// - Throws: `ValueError.runtimeError` when `featureMap` points to a missing file or the
///   library hands back no dump; otherwise whatever `safe` throws for the underlying C call.
public func rawDumped(
    features: [Feature]? = nil,
    featureMap: String = "",
    withStatistics: Bool = false,
    format: ModelFormat = .text
) throws -> [String] {
    let features = features ?? self.features

    // Local out-parameters passed in-out; the previous heap allocations were never deallocated.
    var outLength: UInt64 = 0
    var outResult: UnsafeMutablePointer<UnsafePointer<CChar>?>?

    if let boosterFeatures = features, featureMap == "" {
        // Explicit feature list takes precedence when no feature map file is given.
        var names = boosterFeatures.map(\.name.cCompatible)
        var types = boosterFeatures.map(\.type.rawValue.cCompatible)

        try safe {
            XGBoosterDumpModelExWithFeatures(
                booster,
                Int32(names.count),
                &names,
                &types,
                withStatistics ? 1 : 0,
                format.rawValue,
                &outLength,
                &outResult
            )
        }
    } else {
        if featureMap != "", !FileManager.default.fileExists(atPath: featureMap) {
            throw ValueError.runtimeError("File \(featureMap) does not exists.")
        }

        try safe {
            XGBoosterDumpModelEx(
                booster,
                featureMap,
                withStatistics ? 1 : 0,
                format.rawValue,
                &outLength,
                &outResult
            )
        }
    }

    let count = Int(outLength)
    if count == 0 {
        return []
    }

    guard let dumps = outResult else {
        throw ValueError.runtimeError("Model dump returned no data.")
    }

    return try (0 ..< count).map { index in
        guard let dump = dumps[index] else {
            throw ValueError.runtimeError("Model dump at index \(index) is missing.")
        }
        return String(cString: dump)
    }
}

/// Get feature importance of each feature.
///
/// - Parameter featureMap: Path to the feature map.
/// - Parameter importance: Type of importance you want to compute.
/// - Returns: Tuple of features and gains, in case importance = weight, gains will be nil.
/// - Throws: `ValueError.runtimeError` when the booster type is set and is neither `.gbtree` nor `.dart`,
///   or when a split line of the dump does not carry the requested statistic.
public func score(
    featureMap: String = "",
    importance: Importance = .weight
) throws -> (features: [String: Int], gains: [String: Float]?) {
    if let type = type, ![.gbtree, .dart].contains(type) {
        throw ValueError.runtimeError("Feature importance not defined for \(type).")
    }

    if importance == .weight {
        // Weight only counts how many times each feature is used in a split.
        var fMap = [String: Int]()
        let trees = try rawDumped(
            featureMap: featureMap,
            withStatistics: false,
            format: .text
        )

        for tree in trees {
            for line in tree.components(separatedBy: "\n") {
                let splitted = line.components(separatedBy: "[")

                // Lines without "[" are treated as leaf nodes.
                guard splitted.count > 1 else {
                    continue
                }

                let fid = splitted[1]
                    .components(separatedBy: "]")[0]
                    .components(separatedBy: "<")[0]

                fMap[fid, default: 0] += 1
            }
        }

        return (fMap, nil)
    }

    // Total variants read the same statistic as their averaged counterparts,
    // they just skip the division by the split count at the end.
    var importance = importance
    var averageOverSplits = true

    if importance == .totalGain {
        importance = .gain
        averageOverSplits = false
    }

    if importance == .totalCover {
        importance = .cover
        averageOverSplits = false
    }

    var fMap = [String: Int]()
    var gMap = [String: Float]()
    let importanceSeparator = importance.rawValue + "="
    let trees = try rawDumped(
        featureMap: featureMap,
        withStatistics: true,
        format: .text
    )

    for tree in trees {
        for line in tree.components(separatedBy: "\n") {
            let splitted = line.components(separatedBy: "[")

            // Lines without "[" are treated as leaf nodes.
            guard splitted.count > 1 else {
                continue
            }

            let feature = splitted[1].components(separatedBy: "]")
            let featureName = feature[0].components(separatedBy: "<")[0]

            // Report a malformed line as an error instead of trapping on a
            // missing component or an unparsable number.
            guard feature.count > 1 else {
                throw ValueError.runtimeError("Unable to parse split statistics from line: \(line)")
            }

            let statistics = feature[1].components(separatedBy: importanceSeparator)

            guard
                statistics.count > 1,
                let gain = Float(statistics[1].components(separatedBy: ",")[0])
            else {
                throw ValueError.runtimeError(
                    "Unable to parse \(importance.rawValue) from line: \(line)"
                )
            }

            fMap[featureName, default: 0] += 1
            gMap[featureName, default: 0] += gain
        }
    }

    if averageOverSplits {
        // Both maps are filled together, so they share the same keys.
        for (featureName, splits) in fMap {
            if let total = gMap[featureName] {
                gMap[featureName] = total / Float(splits)
            }
        }
    }

    return (fMap, gMap)
}

/// Loads model from buffer.
///
/// - Parameter modelBuffer: Buffer to load from.
/// - Throws: Whatever `safe` throws for the underlying C call.
public func load(
    modelBuffer buffer: BufferModel
) throws {
    try safe {
        XGBoosterLoadModelFromBuffer(
            booster,
            buffer.data,
            buffer.length
        )
    }
}

/// Loads model from file.
///
/// - Parameter model: Path of file to load model from.
/// - Throws: Whatever `safe` throws for the underlying C call.
public func load(
    model path: String
) throws {
    try safe {
        XGBoosterLoadModel(booster, path)
    }
}

/// Loads model from config.
///
/// - Parameter config: Config to load from.
/// - Throws: Whatever `safe` throws for the underlying C call.
public func load(
    config: String
) throws {
    try safe {
        XGBoosterLoadJsonConfig(booster, config)
    }
}

/// Save the current checkpoint to rabit.
///
/// - Throws: Whatever `safe` throws for the underlying C call.
public func saveRabitCheckpoint() throws {
    try safe {
        XGBoosterSaveRabitCheckpoint(booster)
    }
}

/// Initialize the booster from rabit checkpoint.
///
/// - Returns: The output version of the model.
/// - Throws: Whatever `safe` throws for the underlying C call.
public func loadRabitCheckpoint() throws -> Int {
    var version: Int32 = -1
    try safe {
        XGBoosterLoadRabitCheckpoint(booster, &version)
    }
    return Int(version)
}

/// Get attribute string from the Booster.
///
/// - Parameter name: Name of attribute to get.
/// - Returns: Value of attribute or nil if not set.
/// - Throws: Whatever `safe` throws for the underlying C call.
public func attribute(
    name: String
) throws -> String? {
    var success: Int32 = -1
    // Local out-parameter passed in-out; the previous heap allocation was never deallocated.
    var outResult: UnsafePointer<CChar>?

    try safe {
        XGBoosterGetAttr(booster, name, &outResult, &success)
    }

    // A missing pointer is reported as "not set" instead of trapping on a force unwrap.
    guard success == 1, let value = outResult else {
        return nil
    }

    return String(cString: value)
}

/// Set string attribute.
///
/// - Parameter attribute: Name of attribute.
/// - Parameter value: Value of attribute.
/// - Throws: Whatever `safe` throws for the underlying C call.
public func set(
    attribute: String,
    value: String
) throws {
    try safe {
        XGBoosterSetAttr(
            booster, attribute, value
        )
    }
}

/// Set string parameter.
///
/// - Parameter parameter: Name of parameter.
/// - Parameter value: Value of parameter.
/// - Throws: Whatever `safe` throws for the underlying C call.
public func set(
    parameter: String,
    value: String
) throws {
    try safe {
        XGBoosterSetParam(booster, parameter, value)
    }
}

/// Update for one iteration, with objective function calculated internally.
///
/// - Parameter iteration: Current iteration number.
/// - Parameter data: Training data.
/// - Parameter validateFeatures: Whether to validate features.
/// - Throws: Whatever `validate(data:)` or `safe` throws.
public func update(
    iteration: Int,
    data: Data,
    validateFeatures: Bool = true
) throws {
    if validateFeatures {
        try validate(data: data)
    }

    try safe {
        XGBoosterUpdateOneIter(
            booster,
            Int32(iteration),
            data.dmatrix
        )
    }
}

/// Update for one iteration with custom objective.
///
/// - Parameter data: Training data.
/// - Parameter objective: Objective function returning gradient and hessian.
/// - Parameter validateFeatures: Whether to validate features.
public func update(
    data: Data,
    objective: ObjectiveFunction,
    validateFeatures: Bool = true
) throws {
    let predicted = try predict(
        from: data,
        outputMargin: true,
        training: true,
        validateFeatures: validateFeatures
    )
    let (gradient, hessian) = try objective(predicted, data)

    // Validation is switched off here because `validateFeatures` was already
    // forwarded to the predict call above.
    try boost(
        data: data,
        gradient: gradient,
        hessian: hessian,
        validateFeatures: false
    )
}

/// Boost the booster for one iteration, with customized gradient statistics.
///
/// - Parameter data: Training data.
/// - Parameter gradient: The first order of gradient.
/// - Parameter hessian: The second order of gradient.
/// - Parameter validateFeatures: Whether to validate features.
/// - Throws: `ValueError.runtimeError` when gradient and hessian differ in length;
///   otherwise whatever `validate(data:)` or `safe` throws.
public func boost(
    data: Data,
    gradient: [Float],
    hessian: [Float],
    validateFeatures: Bool = true
) throws {
    if gradient.count != hessian.count {
        throw ValueError.runtimeError(
            "Gradient count \(gradient.count) != Hessian count \(hessian.count)."
        )
    }

    if validateFeatures {
        try validate(data: data)
    }

    // Mutable copies are needed to pass the arrays in-out to the C call.
    var gradient = gradient
    var hessian = hessian

    try safe {
        XGBoosterBoostOneIter(
            booster,
            data.dmatrix,
            &gradient,
            &hessian,
            UInt64(gradient.count)
        )
    }
}

/// Evaluate array of data.
///
/// - Parameter iteration: Current iteration.
/// - Parameter data: Data to evaluate.
/// - Parameter function: Custom function for evaluation.
/// - Returns: Dictionary in format [data_name: [eval_name: eval_value, ...], ...]
/// - Throws: `ValueError.runtimeError` when the library returns no output or an entry
///   that is not in `name:value` form; otherwise whatever validation, `safe`,
///   prediction or `function` throw.
public func evaluate(
    iteration: Int,
    data: [Data],
    function: EvaluationFunction? = nil
) throws -> [String: [String: String]] {
    try validate(data: data)

    var pointees = data.map(\.dmatrix)
    var names = data.map(\.name.cCompatible)
    // Local out-parameter passed in-out; the previous heap allocation was never deallocated.
    var output: UnsafePointer<CChar>?

    try safe {
        XGBoosterEvalOneIter(
            booster,
            Int32(iteration),
            &pointees,
            &names,
            UInt64(data.count),
            &output
        )
    }

    guard let rawOutput = output else {
        throw ValueError.runtimeError("Evaluation returned no output.")
    }

    var results = [String: [String: String]]()
    let outputString = String(cString: rawOutput)

    for (index, result) in outputString.components(separatedBy: .whitespacesAndNewlines).enumerated() {
        // The first token is skipped as before. Empty tokens come from consecutive or
        // trailing whitespace and previously caused an out-of-range access.
        if index == 0 || result.isEmpty {
            continue
        }

        let resultSplitted = result.components(separatedBy: ":")

        guard resultSplitted.count > 1 else {
            throw ValueError.runtimeError("Unable to parse evaluation result: \(result)")
        }

        // The key is split at its last "-": everything before it is the data name,
        // the remainder is the metric name.
        // NOTE(review): a metric name that itself contains "-" is therefore
        // partly attributed to the data name.
        let nameSplitted = resultSplitted[0].components(separatedBy: "-")
        let name = nameSplitted.dropLast().joined(separator: "-")
        let metric = nameSplitted[nameSplitted.count - 1]

        let value = resultSplitted[1]

        results[name, default: [:]][metric] = value
    }

    if let function = function {
        for data in data {
            let (name, result) = try function(
                try predict(from: data, training: false),
                data
            )

            // The entry is created on demand, so a data set without any
            // built-in metric no longer traps here.
            results[data.name, default: [:]][name] = result
        }
    }

    return results
}

/// Evaluate data.
///
/// - Parameter iteration: Current iteration.
/// - Parameter data: Data to evaluate.
/// - Parameter function: Custom function for evaluation.
/// - Returns: Dictionary in format [eval_name: eval_value], empty when nothing was reported for `data`.
/// - Throws: Whatever the array variant of `evaluate` throws.
public func evaluate(
    iteration: Int,
    data: Data,
    function: EvaluationFunction? = nil
) throws -> [String: String] {
    try evaluate(iteration: iteration, data: [data], function: function)[data.name] ?? [:]
}

/// Validate features.
///
/// - Parameter features: Features to validate.
public func validate(
    features: [Feature]
) throws {
    // A booster without stored features adopts the given ones and accepts them as-is.
    guard let known = self.features else {
        self.features = features
        return
    }

    let incomingNames = features.map(\.name)
    let expectedNames = known.map(\.name)

    // Names are compared as ordered arrays; the sets below only build the error report.
    guard incomingNames != expectedNames else {
        return
    }

    let incomingSet = Set(incomingNames)
    let expectedSet = Set(expectedNames)
    let dataMissing = expectedSet.subtracting(incomingSet)
    let boosterMissing = incomingSet.subtracting(expectedSet)

    throw ValueError.runtimeError("""
    Feature names mismatch.
    Missing in data: \(dataMissing).
    Missing in booster: \(boosterMissing).
    """)
}

/// Validate features.
///
/// - Parameter data: Data which features will be validated.
public func validate(
    data: DMatrix
) throws {
    let dataFeatures = try data.features()
    try validate(features: dataFeatures)
}

/// Validate features.
///
/// - Parameter data: Array of data which features will be validated.
public func validate(
    data: [DMatrix]
) throws {
    // Stops at the first matrix whose features fail validation.
    try data.forEach { matrix in
        try validate(features: matrix.features())
    }
}
}

--------------------------------------------------------------------------------