├── .gitignore
├── .swiftpm
└── xcode
│ └── package.xcworkspace
│ └── xcshareddata
│ └── IDEWorkspaceChecks.plist
├── LICENSE
├── Package.resolved
├── Package.swift
├── README.md
├── Sources
├── AgeKit
│ ├── Age.swift
│ ├── Armor
│ │ └── Armor.swift
│ ├── Extensions
│ │ ├── ByteBuffer.swift
│ │ ├── ChaChaPoly+properties.swift
│ │ ├── Curve25519+properties.swift
│ │ ├── Data+ReadLine.swift
│ │ ├── OutputStream.swift
│ │ └── bytes.swift
│ ├── Format
│ │ └── Format.swift
│ ├── Parse.swift
│ ├── Primitives.swift
│ ├── Scrypt.swift
│ ├── Stream
│ │ └── Stream.swift
│ └── X25519.swift
└── Bech32
│ └── Bech32.swift
├── Tests
├── AgeKitTests
│ ├── AgeKitTests.swift
│ ├── ArmorTests
│ │ └── ArmorTests.swift
│ ├── FormatTests
│ │ └── FormatTests.swift
│ └── StreamTests
│ │ ├── StreamTests.swift
│ │ ├── StreamTests.swift.gyb
│ │ └── StreamTests_gen.swift
└── Bech32Tests
│ └── Bech32Tests.swift
├── gyb
└── gyb.py
/.gitignore:
--------------------------------------------------------------------------------
1 | .DS_Store
2 | /.build
3 | /Packages
4 | /*.xcodeproj
5 | xcuserdata/
6 | DerivedData/
7 | .swiftpm/config/registries.json
8 | .swiftpm/xcode/package.xcworkspace/contents.xcworkspacedata
9 | .netrc
10 |
--------------------------------------------------------------------------------
/.swiftpm/xcode/package.xcworkspace/xcshareddata/IDEWorkspaceChecks.plist:
--------------------------------------------------------------------------------
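1 | <?xml version="1.0" encoding="UTF-8"?>
2 | <!DOCTYPE plist PUBLIC "-//Apple//DTD PLIST 1.0//EN" "http://www.apple.com/DTDs/PropertyList-1.0.dtd">
3 | <plist version="1.0">
4 | <dict>
5 | 	<key>IDEDidComputeMac32BitWarning</key>
6 | 	<true/>
7 | </dict>
8 | </plist>
9 |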
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | BSD 3-Clause License
2 |
3 | Copyright (c) 2023, James O'Gorman
4 |
5 | Redistribution and use in source and binary forms, with or without
6 | modification, are permitted provided that the following conditions are met:
7 |
8 | 1. Redistributions of source code must retain the above copyright notice, this
9 | list of conditions and the following disclaimer.
10 |
11 | 2. Redistributions in binary form must reproduce the above copyright notice,
12 | this list of conditions and the following disclaimer in the documentation
13 | and/or other materials provided with the distribution.
14 |
15 | 3. Neither the name of the copyright holder nor the names of its
16 | contributors may be used to endorse or promote products derived from
17 | this software without specific prior written permission.
18 |
19 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29 |
--------------------------------------------------------------------------------
/Package.resolved:
--------------------------------------------------------------------------------
1 | {
2 | "pins" : [
3 | {
4 | "identity" : "swift-asn1",
5 | "kind" : "remoteSourceControl",
6 | "location" : "https://github.com/apple/swift-asn1.git",
7 | "state" : {
8 | "branch" : "main",
9 | "revision" : "155b8acdb2a9be1db8a195b2f6a9e6c18d9b754e"
10 | }
11 | },
12 | {
13 | "identity" : "swift-atomics",
14 | "kind" : "remoteSourceControl",
15 | "location" : "https://github.com/apple/swift-atomics.git",
16 | "state" : {
17 | "revision" : "ff3d2212b6b093db7f177d0855adbc4ef9c5f036",
18 | "version" : "1.0.3"
19 | }
20 | },
21 | {
22 | "identity" : "swift-collections",
23 | "kind" : "remoteSourceControl",
24 | "location" : "https://github.com/apple/swift-collections.git",
25 | "state" : {
26 | "revision" : "937e904258d22af6e447a0b72c0bc67583ef64a2",
27 | "version" : "1.0.4"
28 | }
29 | },
30 | {
31 | "identity" : "swift-extras-base64",
32 | "kind" : "remoteSourceControl",
33 | "location" : "https://github.com/swift-extras/swift-extras-base64.git",
34 | "state" : {
35 | "revision" : "97237cf1bc1feebaeb0febec91c1e1b9e4d839b3",
36 | "version" : "0.7.0"
37 | }
38 | },
39 | {
40 | "identity" : "swift-nio",
41 | "kind" : "remoteSourceControl",
42 | "location" : "https://github.com/apple/swift-nio.git",
43 | "state" : {
44 | "revision" : "45167b8006448c79dda4b7bd604e07a034c15c49",
45 | "version" : "2.48.0"
46 | }
47 | },
48 | {
49 | "identity" : "swift-scrypt",
50 | "kind" : "remoteSourceControl",
51 | "location" : "https://github.com/greymass/swift-scrypt.git",
52 | "state" : {
53 | "revision" : "631f21c36bff63e33ad13353ee801b4a032dda15",
54 | "version" : "1.0.2"
55 | }
56 | }
57 | ],
58 | "version" : 2
59 | }
60 |
--------------------------------------------------------------------------------
/Package.swift:
--------------------------------------------------------------------------------
// swift-tools-version: 5.7
// The swift-tools-version declares the minimum version of Swift required to build this package.

import PackageDescription

let package = Package(
    name: "AgeKit",
    // Parts of CryptoKit require macOS 11+, iOS 14+
    platforms: [
        .macOS(.v11),
        .iOS(.v14)
    ],
    products: [
        .library(
            name: "AgeKit",
            targets: ["AgeKit"]),
    ],
    dependencies: [
        .package(url: "https://github.com/greymass/swift-scrypt.git", from: "1.0.0"),
        // AgeKit sources `import NIOCore`, a module swift-nio only ships from 2.32.0 onwards,
        // so older versions must not be allowed to resolve.
        .package(url: "https://github.com/apple/swift-nio.git", from: "2.32.0"),
        .package(url: "https://github.com/swift-extras/swift-extras-base64.git", from: "0.7.0"),
        // TODO: Change this to a version once they've tagged a release which includes 155b8acdb2a9be1db8a195b2f6a9e6c18d9b754e
        .package(url: "https://github.com/apple/swift-asn1.git", branch: "main"),
    ],
    targets: [
        .target(
            name: "AgeKit",
            dependencies: [
                "Bech32",
                .product(name: "Scrypt", package: "swift-scrypt"),
                .product(name: "NIO", package: "swift-nio"),
                // Declared explicitly because the sources import NIOCore directly rather than
                // relying on it arriving transitively through the NIO umbrella product.
                .product(name: "NIOCore", package: "swift-nio"),
                .product(name: "ExtrasBase64", package: "swift-extras-base64")
            ]),
        .target(name: "Bech32"),
        .testTarget(
            name: "Bech32Tests",
            dependencies: ["Bech32"]),
        .testTarget(
            name: "AgeKitTests",
            dependencies: ["AgeKit", .product(name: "SwiftASN1", package: "swift-asn1")],
            exclude: [
                // Template source; StreamTests_gen.swift is the generated output that gets compiled.
                "StreamTests/StreamTests.swift.gyb",
            ]),
    ]
)
46 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |

2 |
3 | # AgeKit: Swift implementation of age
4 |
5 | [age](https://age-encryption.org) is a simple, modern and secure file encryption tool and format. It features small explicit keys, no config options, and UNIX-style composability.
6 |
7 | AgeKit provides a Swift implementation of the library and format.
8 |
9 | The reference Go implementation is available at [filippo.io/age](https://filippo.io/age).
10 |
11 | ## Implementation Notes
12 |
13 | These features of age have been implemented:
14 |
15 | - ✅ X25519 public/private keys
16 | - ⚠️ Scrypt passphrases: implemented, but format issues currently break
17 | compatibility with the Go version of the tool
18 | - ✅ Armored (PEM) encoding and decoding
19 | - ❌ SSH keys
20 | - ❌ GitHub users
21 |
22 | AgeKit uses features from CryptoKit and needs at least macOS 11 or iOS 14.
23 |
24 | ## Getting Started
25 |
26 | To use AgeKit add the following dependency to `Package.swift`:
27 |
28 | ```swift
29 | dependencies: [
30 | .package(url: "https://github.com/jamesog/AgeKit.git", branch: "main"),
31 | ]
32 | ```
33 |
34 | You can then add the dependency to your target:
35 |
36 | ```swift
37 | dependencies: [
38 | "AgeKit",
39 | ]
40 | ```
41 |
--------------------------------------------------------------------------------
/Sources/AgeKit/Age.swift:
--------------------------------------------------------------------------------
1 | import Foundation
2 | import CryptoKit
3 |
/// Something that can recover the opaque file key from an age header's recipient stanzas.
///
/// Typical implementations are secret keys such as `X25519Identity`, plugins,
/// or custom types. `Age.decrypt(src:identities:)` tries each identity in turn.
/// When an identity throws `Age.DecryptError.incorrectIdentity`, `decrypt`
/// moves on to the next one; any other error aborts decryption.
///
/// Most users never call this directly. Instead, pass `Recipient` conformers
/// to `encrypt` and `Identity` conformers to `decrypt`.
public protocol Identity {
    /// Returns the file key recovered from one of `stanzas`.
    func unwrap(stanzas: [Age.Stanza]) throws -> SymmetricKey
}

/// Something that can wrap the opaque file key into one or more recipient stanzas.
///
/// Typical implementations are public keys such as `X25519Recipient`, plugins,
/// or custom types.
///
/// Most users never call this directly. Instead, pass `Recipient` conformers
/// to `encrypt` and `Identity` conformers to `decrypt`.
public protocol Recipient {
    /// Returns the stanzas that carry `fileKey` for this recipient.
    func wrap(fileKey: SymmetricKey) throws -> [Age.Stanza]
}
25 |
26 | // MARK: -
27 |
/// Namespace for the age encryption API.
public enum Age {
    /// Length in bytes of the random file key.
    static let fileKeySize = 16
    /// Length in bytes of the nonce that precedes the encrypted payload.
    static let streamNonceSize = 16

    /// One section of the age header: the file key encrypted to a single recipient.
    ///
    /// Most users never touch stanzas directly. Instead, pass `Recipient`
    /// conformers to `encrypt` and `Identity` conformers to `decrypt`.
    public struct Stanza {
        var type: String
        var args: [String]
        var body: Data

        /// An empty stanza with no type, no arguments and an empty body.
        init() {
            self.init(type: "", args: [], body: nil)
        }

        /// A stanza with the given fields; a `nil` body becomes empty `Data`.
        init(type: String, args: [String], body: Data?) {
            self.type = type
            self.args = args
            self.body = body ?? Data()
        }

        /// Converts a wire-format stanza into the public representation.
        init(_ s: Format.Stanza) {
            self.init(type: s.type, args: s.args, body: s.body)
        }
    }
}
63 |
64 | // MARK: - Encrypt
65 |
extension Age {
    /// Errors thrown by `Age.encrypt(dst:recipients:)`.
    public enum EncryptError: Error {
        /// No recipients were passed.
        case noRecipients
        /// A `ScryptRecipient` was combined with other recipients.
        case scryptRecipientMustBeOnlyOne
        /// The system random generator failed to produce the payload nonce.
        case nonceGeneration
        /// Writing the payload nonce to `dst` failed; the underlying error is attached.
        case nonceWrite(Error)
    }

    /// Encrypt a file to one or more recipients.
    ///
    /// Writes to the returned `StreamWriter` are encrypted and written to `dst` as an age file.
    /// Every recipient will be able to decrypt the file.
    ///
    /// The caller must call `close()` on the `StreamWriter` when done for the last chunk to
    /// be encrypted and flushed to `dst`.
    ///
    /// - Throws: `EncryptError` for argument, randomness and nonce-write failures.
    ///   Errors from recipients' `wrap` and from writing the header are propagated unchanged.
    public static func encrypt(dst: inout OutputStream, recipients: Recipient...) throws -> StreamWriter {
        guard !recipients.isEmpty else {
            throw EncryptError.noRecipients
        }
        // A passphrase recipient must stand alone; mixing it with others is rejected.
        if recipients.count != 1 && recipients.contains(where: { $0 is ScryptRecipient }) {
            throw EncryptError.scryptRecipientMustBeOnlyOne
        }

        let fileKey = SymmetricKey(size: .bits128)

        var hdr = Format.Header()
        for r in recipients {
            hdr.recipients += try r.wrap(fileKey: fileKey).map { Format.Stanza($0) }
        }
        hdr.mac = try headerMAC(fileKey: fileKey, hdr: hdr)
        try hdr.encode(to: &dst)

        var nonce = [UInt8](repeating: 0, count: streamNonceSize)
        let status = SecRandomCopyBytes(kSecRandomDefault, nonce.count, &nonce)
        guard status == errSecSuccess else {
            throw EncryptError.nonceGeneration
        }

        // The `write` extension already throws the stream's error itself. Catch it here so callers
        // get the documented `nonceWrite` case instead of a bare stream error.
        do {
            _ = try dst.write(Data(nonce))
        } catch {
            throw EncryptError.nonceWrite(error)
        }
        if let streamError = dst.streamError {
            throw EncryptError.nonceWrite(streamError)
        }

        return StreamWriter(fileKey: streamKey(fileKey: fileKey, nonce: nonce), dst: dst)
    }
}
118 |
119 | // MARK: - Decrypt
120 |
extension Age {
    /// Errors thrown while decrypting.
    public enum DecryptError: Error {
        /// The identity cannot unwrap any stanza; `decrypt` then tries the next identity.
        case incorrectIdentity
        /// No identities were passed, or none of them could unwrap the file key.
        case noIdentities
        /// Not thrown by the code in this file.
        case computingHeaderMAC
        /// The header MAC does not match the one recomputed from the file key.
        case badHeaderMAC
        /// The payload nonce could not be read in full; the underlying cause is attached.
        case nonceRead(Error)
    }

    /// Cause attached to `DecryptError.nonceRead` when the payload ends before a full nonce was read.
    struct TruncatedNonceError: Error {
        /// Number of nonce bytes required.
        let expected: Int
        /// Number of bytes available before the payload ended.
        let actual: Int
    }

    /// Decrypts a file encrypted to one or more identities.
    ///
    /// It returns a `StreamReader` for reading the decrypted plaintext of the age file read
    /// from `src`. All identities will be tried until one successfully decrypts the file.
    ///
    /// - Throws: `DecryptError.noIdentities` if `identities` is empty or none matches,
    ///   `DecryptError.badHeaderMAC` if the header was tampered with,
    ///   `DecryptError.nonceRead` if the payload is too short to hold the nonce, and
    ///   any parse or identity error other than `incorrectIdentity`.
    public static func decrypt(src: InputStream, identities: Identity...) throws -> StreamReader {
        guard !identities.isEmpty else {
            throw DecryptError.noIdentities
        }

        let (hdr, payload) = try Format.parse(input: src)
        let stanzas = hdr.recipients.map { Stanza($0) }

        var fileKey: SymmetricKey?
        for id in identities {
            do {
                fileKey = try id.unwrap(stanzas: stanzas)
                break
            } catch DecryptError.incorrectIdentity {
                // Not this identity; any other error propagates out of `decrypt`.
                continue
            }
        }
        guard let fileKey else {
            throw DecryptError.noIdentities
        }

        let mac = try headerMAC(fileKey: fileKey, hdr: hdr)
        guard constantTimeEqual(mac, hdr.mac) else {
            throw DecryptError.badHeaderMAC
        }

        // The read result was previously ignored, so a truncated payload silently left the
        // nonce zero-filled. Read it in full or fail.
        let nonce = try readNonce(count: streamNonceSize, from: payload)

        return StreamReader(
            fileKey: streamKey(fileKey: fileKey, nonce: nonce),
            src: payload)
    }

    /// Tries `unwrap` on each stanza in order and returns the first key it produces.
    ///
    /// Stanzas for which `unwrap` throws `DecryptError.incorrectIdentity` are skipped.
    /// Any other error is rethrown immediately.
    ///
    /// - Throws: `DecryptError.incorrectIdentity` if no stanza could be unwrapped.
    static func multiUnwrap(stanzas: [Stanza], unwrap: (Stanza) throws -> SymmetricKey) throws -> SymmetricKey {
        for stanza in stanzas {
            do {
                return try unwrap(stanza)
            } catch DecryptError.incorrectIdentity {
                continue
            }
        }

        throw DecryptError.incorrectIdentity
    }

    /// Reads exactly `count` bytes from `stream`, retrying on short reads.
    ///
    /// - Throws: `DecryptError.nonceRead`, wrapping either the stream's error or a `TruncatedNonceError`.
    private static func readNonce(count: Int, from stream: InputStream) throws -> [UInt8] {
        var nonce = [UInt8](repeating: 0, count: count)
        var filled = 0
        while filled < count {
            let n = nonce.withUnsafeMutableBufferPointer { ptr in
                stream.read(ptr.baseAddress! + filled, maxLength: count - filled)
            }
            if n < 0 {
                let cause = stream.streamError ?? (TruncatedNonceError(expected: count, actual: filled) as Error)
                throw DecryptError.nonceRead(cause)
            }
            if n == 0 {
                break
            }
            filled += n
        }
        guard filled == count else {
            throw DecryptError.nonceRead(TruncatedNonceError(expected: count, actual: filled))
        }
        return nonce
    }

    /// Compares two MACs without exiting early at the first differing byte.
    private static func constantTimeEqual(_ a: Data, _ b: Data) -> Bool {
        guard a.count == b.count else {
            return false
        }
        return zip(a, b).reduce(UInt8(0)) { acc, pair in acc | (pair.0 ^ pair.1) } == 0
    }
}
187 |
--------------------------------------------------------------------------------
/Sources/AgeKit/Armor/Armor.swift:
--------------------------------------------------------------------------------
1 | import Foundation
2 | import NIOCore
3 |
/// Framing constants for the ASCII-armored (PEM-style) form of an age file.
public enum Armor {
    /// The line that opens an armored file.
    public static let header = "-----BEGIN AGE ENCRYPTED FILE-----"
    /// The line that closes an armored file.
    public static let footer = "-----END AGE ENCRYPTED FILE-----"
}

/// Errors raised by `Armor.Writer` and `Armor.Reader`.
public enum ArmorError: Error {
    /// `close()` found the destination stream already closed.
    case writerAlreadyClosed
    /// The first line was not `Armor.header`; the line that was found is attached.
    case invalidHeader(String)
    /// Something other than whitespace follows the footer.
    case trailingDataAfterArmor
    /// Too much whitespace follows the footer.
    case tooMuchTrailingWhitespace
    /// The armored body could not be base64-decoded.
    case base64DecodeError
}
16 |
17 | // MARK: - Writer
18 |
extension Armor {
    /// Writes age data to an `OutputStream` in the ASCII-armored (PEM-style) format.
    ///
    /// The header line is emitted on the first `write(_:)`, and `close()` emits the footer.
    ///
    /// NOTE(review): each `write(_:)` call base64-encodes its argument on its own and ends it
    /// with a newline. Splitting one payload across several calls can therefore leave padding
    /// and short lines in the middle of the armor. Pass the complete payload to a single
    /// `write(_:)` call.
    public struct Writer {
        let dst: OutputStream
        private var started = false
        private var written = 0

        /// Creates a writer that emits armored output to `dst`.
        public init(dst: OutputStream) {
            self.dst = dst
        }

        /// Base64-encodes `data` (64-column lines, LF endings) and writes it, followed by a newline.
        ///
        /// - Returns: The running total of body bytes written by every `write(_:)` call on this
        ///   writer so far, not just this one. The header line is not counted.
        /// - Throws: The destination stream's error when a write fails.
        public mutating func write(_ data: Data) throws -> Int {
            if !started {
                let headerLine = Armor.header + "\n"
                _ = try dst.write(headerLine)
            }
            started = true

            let body = data.base64EncodedString(options: [.lineLength64Characters, .endLineWithLineFeed])
            let bodyBytes = try dst.write(body)
            written += bodyBytes
            let newlineBytes = try dst.write("\n")
            written += newlineBytes
            return written
        }

        /// Writes the footer line, without a trailing newline.
        ///
        /// - Throws: `ArmorError.writerAlreadyClosed` if `dst` is already closed, or the stream's
        ///   error if the write fails.
        public func close() throws {
            guard dst.streamStatus != .closed else {
                throw ArmorError.writerAlreadyClosed
            }
            _ = try dst.write(Armor.footer)
        }
    }
}
49 |
50 | // MARK: - Reader
51 |
extension Armor {
    /// Reads an ASCII-armored (PEM-style) age file and yields its binary contents.
    public struct Reader {
        private var src: ByteBuffer
        private var encoded = ""
        private var started = false
        private let maxWhitespace = 1024

        public init(src: InputStream) {
            // FIXME: Consuming the entire input is probably bad, but InputStream is too hard to work with.
            self.src = ByteBuffer(src)
        }

        /// Decodes the armored input and appends the binary contents to `buffer`.
        ///
        /// The first call consumes all remaining input. Later calls append nothing and return `0`.
        ///
        /// - Returns: The number of armored characters consumed by this call (line terminators
        ///   included), or `-1` if a line could not be read.
        /// - Throws: `ArmorError.invalidHeader` if the first line is not `Armor.header`,
        ///   `ArmorError.trailingDataAfterArmor` / `.tooMuchTrailingWhitespace` for bad data
        ///   after the footer, and `ArmorError.base64DecodeError` for an undecodable body.
        public mutating func read(_ buffer: inout [UInt8]) throws -> Int {
            var read = 0
            if !started {
                guard let line = src.readString(until: "\n") else {
                    return -1
                }
                let header = line.trimmingCharacters(in: ["\n"])
                if header != Armor.header {
                    throw ArmorError.invalidHeader(header)
                }
                started = true
            }

            while src.readableBytes > 0 {
                guard let line = src.readString(until: "\n") else {
                    return -1
                }
                // `readString(until:)` keeps the "\n" terminator. Strip it before comparing,
                // otherwise a footer followed by a newline (as the Go implementation writes it)
                // would be fed into the base64 decoder.
                if line.trimmingCharacters(in: ["\n"]) == Armor.footer {
                    break
                }
                encoded += line
                read += line.count
            }
            if src.readableBytes > 0 {
                let trailing = src.readString(length: min(src.readableBytes, maxWhitespace))
                // Newlines after the footer count as whitespace, not as trailing data.
                guard let trailing, trailing.trimmingCharacters(in: .whitespacesAndNewlines).isEmpty else {
                    throw ArmorError.trailingDataAfterArmor
                }
                guard trailing.count < maxWhitespace else {
                    throw ArmorError.tooMuchTrailingWhitespace
                }
            }
            guard let enc = Data(base64Encoded: encoded, options: .ignoreUnknownCharacters) else {
                throw ArmorError.base64DecodeError
            }
            // Everything accumulated so far has now been delivered. Clear it so that a later call
            // doesn't append the same plaintext a second time.
            encoded = ""

            buffer.append(contentsOf: enc)
            return read
        }
    }
}
105 |
--------------------------------------------------------------------------------
/Sources/AgeKit/Extensions/ByteBuffer.swift:
--------------------------------------------------------------------------------
1 | import Foundation
2 | import NIOCore
3 |
extension ByteBuffer {
    /// Create a fresh ByteBuffer from the given `InputStream`.
    ///
    /// The stream is read in 4096-byte chunks until it reports no more bytes available, end of
    /// stream (`0`) or an error (negative result). Errors are not reported; the buffer holds
    /// whatever was read before the failure.
    init(_ input: InputStream) {
        let bufSize = 4096
        self = ByteBufferAllocator().buffer(capacity: bufSize)
        var buf = [UInt8](repeating: 0, count: bufSize)
        while input.hasBytesAvailable {
            let result = input.read(&buf, maxLength: bufSize)
            // 0 means end of stream and a negative value means an error; stop either way.
            if result <= 0 {
                break
            }
            self.writeBytes(buf[..<result])
        }
    }

    /// Returns the number of readable bytes up to and including the first `delim`,
    /// or `nil` if `delim` does not occur in the readable bytes.
    ///
    /// The result is relative to `readerIndex`, so it can be passed directly as the `length`
    /// of `readBytes(length:)` / `readString(length:)`.
    ///
    /// - Precondition: `delim` is an ASCII character.
    func indexOf(delim: Character) -> Int? {
        guard let char = delim.asciiValue else {
            preconditionFailure("indexOf(delim:) requires an ASCII delimiter, got \(delim)")
        }
        // Search the whole readable region in one pass. The earlier fixed-size chunked scan
        // advanced its position incorrectly and returned chunk-relative offsets, so a
        // delimiter beyond the first 4096 bytes was reported as missing.
        let view = self.readableBytesView
        guard let found = view.firstIndex(of: char) else {
            return nil
        }
        return view.distance(from: view.startIndex, to: found) + 1
    }

    /// Reads up to and including the first `delim` and returns those bytes, advancing the
    /// reader index past them.
    ///
    /// If `delim` is not found, all remaining readable bytes are returned instead.
    mutating func readBytes(until delim: Character) -> [UInt8]? {
        guard let index = self.indexOf(delim: delim) else {
            return self.readBytes(length: self.readableBytes)
        }
        return self.readBytes(length: index)
    }

    /// Reads up to and including the first `delim`, decoding it as a UTF-8 `String` and
    /// advancing the reader index past it.
    ///
    /// If `delim` is not found, all remaining readable bytes are decoded instead.
    mutating func readString(until delim: Character) -> String? {
        guard let index = self.indexOf(delim: delim) else {
            return self.readString(length: self.readableBytes)
        }
        return self.readString(length: index)
    }
}
60 |
--------------------------------------------------------------------------------
/Sources/AgeKit/Extensions/ChaChaPoly+properties.swift:
--------------------------------------------------------------------------------
1 | import CryptoKit
2 |
extension ChaChaPoly {
    /// Key length in bytes.
    static let keySize = 32
    /// Nonce length in bytes.
    static let nonceSize = 12
    /// Authentication tag length in bytes.
    static let tagSize = 16
    /// Bytes added to a sealed message once the nonce is excluded: equal to `tagSize`.
    static let overhead = tagSize
}
9 |
--------------------------------------------------------------------------------
/Sources/AgeKit/Extensions/Curve25519+properties.swift:
--------------------------------------------------------------------------------
1 | import CryptoKit
2 |
extension Curve25519 {
    /// Length in bytes of an encoded Curve25519 point.
    static let pointSize: Int = 32
}
6 |
--------------------------------------------------------------------------------
/Sources/AgeKit/Extensions/Data+ReadLine.swift:
--------------------------------------------------------------------------------
1 | import Foundation
2 |
3 | extension Data {
4 | func readLine(delimiter: String = "\n") -> Data? {
5 | let delimiter = delimiter.data(using: .utf8)!
6 | let delimiterRange = self.range(of: delimiter)
7 | let lineRange = 0.. Int {
5 | if s.isEmpty { return 0 }
6 | let bytes = [UInt8](s.utf8)
7 | let ret = write(bytes, maxLength: bytes.count)
8 | if let streamError = streamError {
9 | throw streamError
10 | }
11 | return ret
12 | }
13 |
/// Writes all of `d` to the stream, retrying when the stream accepts only part of it.
///
/// - Returns: The number of bytes written (`0` for empty input). This is less than `d.count`
///   only if the stream stops accepting bytes. It is the raw negative result if the very first
///   write fails without setting `streamError`.
/// - Throws: The stream's `streamError` as soon as one is set after a write.
func write(_ d: Data) throws -> Int {
    if d.isEmpty { return 0 }
    return try d.withUnsafeBytes { (raw: UnsafeRawBufferPointer) -> Int in
        let bytes = raw.bindMemory(to: UInt8.self)
        var total = 0
        // OutputStream.write may accept fewer bytes than offered; keep going until done.
        while total < bytes.count {
            let ret = self.write(bytes.baseAddress! + total, maxLength: bytes.count - total)
            if let streamError = self.streamError {
                throw streamError
            }
            if ret <= 0 {
                return total == 0 ? ret : total
            }
            total += ret
        }
        return total
    }
}
23 | }
24 |
--------------------------------------------------------------------------------
/Sources/AgeKit/Extensions/bytes.swift:
--------------------------------------------------------------------------------
1 | import Foundation
2 |
extension Data {
    /// The contents of this `Data` copied into a byte array.
    var bytes: [UInt8] {
        [UInt8](self)
    }
}

extension String {
    /// The UTF-8 encoding of this string as a byte array.
    public var bytes: [UInt8] {
        if let encoded = data(using: String.Encoding.utf8, allowLossyConversion: true) {
            return encoded.bytes
        }
        return [UInt8](utf8)
    }
}
14 |
--------------------------------------------------------------------------------
/Sources/AgeKit/Format/Format.swift:
--------------------------------------------------------------------------------
1 | import CryptoKit
2 | import ExtrasBase64
3 | import Foundation
4 | import NIOCore
5 |
/// Constants and helpers for the age v1 header format.
public enum Format {
    /// The version line that starts every age file, newline included.
    static let intro = "age-encryption.org/v1\n"

    /// Bytes that open every recipient stanza line.
    static let stanzaPrefix = Data("->".utf8)
    /// Bytes that open the header's closing MAC line.
    static let footerPrefix = Data("---".utf8)
}
12 |
13 | // MARK: - Header
14 |
extension Format {
    /// The age header: the recipient stanzas plus a MAC over the header.
    public struct Header {
        var recipients: [Stanza] = []
        var mac = Data()

        /// Feeds the header bytes that the MAC covers into `hash`.
        ///
        /// These are the intro line, every encoded stanza, and the "---" footer prefix, i.e. the
        /// same bytes `encodeWithoutMAC(to:)` writes to a stream.
        public func encodeWithoutMAC(to hash: inout HMAC<SHA256>) {
            hash.update(data: Data(Format.intro.utf8))
            for stanza in recipients {
                hash.update(data: stanza.encode())
            }
            hash.update(data: Format.footerPrefix)
        }

        /// Writes the header up to and including the "---" footer prefix, without the MAC.
        public func encodeWithoutMAC(to output: inout OutputStream) throws {
            _ = try output.write(Data(Format.intro.utf8))
            for stanza in recipients {
                try stanza.encode(to: &output)
            }
            _ = try output.write(Format.footerPrefix)
        }

        /// Writes the complete header: everything from `encodeWithoutMAC(to:)`, then a space,
        /// the unpadded base64 MAC and a newline.
        public func encode(to output: inout OutputStream) throws {
            try encodeWithoutMAC(to: &output)
            let macBase64 = Base64.encodeString(bytes: mac, options: .omitPaddingCharacter)
            _ = try output.write(Data(" ".utf8))
            _ = try output.write(macBase64)
            _ = try output.write(Data("\n".utf8))
        }
    }
}
45 |
46 | // MARK: - Stanza
47 |
extension Format {
    /// Width in base64 columns of a full stanza body line.
    static let columnsPerLine = 64
    /// Raw bytes carried by a full body line (64 base64 columns hold 48 bytes).
    static var bytesPerLine: Int { columnsPerLine / 4 * 3 }

    /// A recipient stanza in its wire form.
    public struct Stanza {
        var type: String
        var args: [String]
        var body: Data

        init() {
            self.type = ""
            self.args = []
            self.body = Data()
        }

        init(_ s: Age.Stanza) {
            self.type = s.type
            self.args = s.args
            self.body = s.body
        }

        /// Returns the bytes `encode(to:)` would write.
        public func encode() -> Data {
            var out = OutputStream.toMemory()
            out.open()
            // Writing to an in-memory stream is not expected to fail.
            try! encode(to: &out)
            out.close()
            return out.property(forKey: .dataWrittenToMemoryStreamKey) as? Data ?? Data()
        }

        /// Writes "-> type args...\n" followed by the unpadded base64 body, wrapped at
        /// `columnsPerLine` columns and terminated by a short (possibly empty) final line.
        public func encode(to: inout OutputStream) throws {
            var stanza = String(decoding: Format.stanzaPrefix, as: UTF8.self)
            for a in [type] + args {
                stanza += " \(a)"
            }
            stanza += "\n"
            // ExtrasBase64 produces a single unbroken line. Wrap it, otherwise bodies longer
            // than `bytesPerLine` produce an over-long line that no parser accepts.
            let b64 = Base64.encodeString(bytes: body, options: .omitPaddingCharacter)
            stanza += Format.wrapBase64(b64, width: Format.columnsPerLine)
            // The format is a little finicky and requires some short lines.
            // When the input is divisible by bytesPerLine the last line is full,
            // so an extra empty line is needed to terminate the body.
            if !body.isEmpty && body.count % Format.bytesPerLine == 0 {
                stanza += "\n"
            }
            stanza += "\n"
            _ = try to.write(Data(stanza.utf8))
        }
    }

    /// Splits `s` into lines of at most `width` characters joined by "\n" (no trailing newline).
    private static func wrapBase64(_ s: String, width: Int) -> String {
        var lines: [Substring] = []
        var rest = s[...]
        while rest.count > width {
            lines.append(rest.prefix(width))
            rest = rest.dropFirst(width)
        }
        lines.append(rest)
        return lines.joined(separator: "\n")
    }

    enum DecodeError: Error {
        case unexpectedNewLineError
    }

    /// Decodes unpadded standard base64, rejecting any CR or LF.
    ///
    /// - Throws: `DecodeError.unexpectedNewLineError` if `s` contains "\n" or "\r", or the
    ///   base64 decoder's error for invalid input.
    public static func decodeString(_ s: String) throws -> Data {
        // Inspect raw bytes: "\r\n" is a single Character, so Character-based checks miss it.
        // Rejecting line breaks avoids malleable encodings.
        if s.utf8.contains(where: { $0 == UInt8(ascii: "\n") || $0 == UInt8(ascii: "\r") }) {
            throw DecodeError.unexpectedNewLineError
        }
        let b64 = try Base64.decode(string: s, options: .omitPaddingCharacter)
        return Data(b64)
    }

    enum StanzaError: Error {
        case lineError
        case malformedOpeningLine
        case malformedStanza
        case malformedBodyLineSize
    }

    /// Reads recipient stanzas from a buffer positioned at a stanza's opening line.
    struct StanzaReader {
        var buf: ByteBuffer

        init(_ buf: ByteBuffer) {
            self.buf = buf
        }

        /// Reads one stanza: the "->" opening line, then body lines up to and including the
        /// first line shorter than `bytesPerLine` decoded bytes.
        ///
        /// - Throws: `StanzaError` for malformed lines, or the base64 error for a bad body line.
        public mutating func readStanza() throws -> Stanza {
            var stanza = Stanza()

            guard let line = buf.readBytes(until: "\n") else {
                throw StanzaError.lineError
            }
            guard line.starts(with: stanzaPrefix) else {
                throw StanzaError.malformedOpeningLine
            }
            guard let parts = splitArgs(line: line),
                  Data(parts.prefix.utf8) == stanzaPrefix,
                  !parts.args.isEmpty,
                  parts.args.allSatisfy(isValidString) else {
                throw StanzaError.malformedStanza
            }
            stanza.type = parts.args[0]
            stanza.args = Array(parts.args.dropFirst())

            while true {
                guard let line = buf.readBytes(until: "\n") else {
                    throw StanzaError.lineError
                }
                // Untrusted input: invalid UTF-8 is a malformed stanza, not a crash.
                guard let text = String(bytes: line, encoding: .utf8) else {
                    throw StanzaError.malformedStanza
                }
                // Errors propagate, as in the reference implementation. Previously they were
                // swallowed, so a missing short final line made the reader run on into the next
                // stanza, the footer, or the payload.
                let b = try decodeString(text.trimmingCharacters(in: ["\n"]))
                guard b.count <= bytesPerLine else {
                    throw StanzaError.malformedBodyLineSize
                }
                stanza.body.append(b)
                if b.count < bytesPerLine {
                    return stanza
                }
            }
        }
    }

    enum ParseError: Error {
        case introRead
        case unexpectedIntro
        case readHeader
        case malformedClosingLine
        case internalError
    }

    /// Parses the age header from `input` and returns it together with an already-opened
    /// stream over the remaining bytes (the payload).
    public static func parse(input: InputStream) throws -> (Header, InputStream) {
        var header = Header()
        // Consume the entire input
        // FIXME: We shouldn't do this and should read chunks at a time
        var buf = ByteBuffer(input)

        guard let line = buf.readString(until: "\n") else {
            throw ParseError.introRead
        }
        guard line == Format.intro else {
            throw ParseError.unexpectedIntro
        }

        while true {
            guard let peek = buf.getBytes(at: buf.readerIndex, length: footerPrefix.count) else {
                throw ParseError.readHeader
            }
            if peek == Array(footerPrefix) {
                guard let line = buf.readBytes(until: "\n") else {
                    throw ParseError.readHeader
                }

                guard let parts = splitArgs(line: line),
                      parts.prefix == String(decoding: footerPrefix, as: UTF8.self),
                      parts.args.count == 1 else {
                    throw ParseError.malformedClosingLine
                }
                header.mac = try decodeString(parts.args[0])
                guard header.mac.count == 32 else {
                    throw ParseError.malformedClosingLine
                }
                break
            }

            var sr = StanzaReader(buf)
            let s = try sr.readStanza()
            buf = sr.buf // read buf back to get the position advances
            header.recipients.append(s)
        }

        guard let rest = buf.getBytes(at: buf.readerIndex, length: buf.readableBytes) else {
            throw ParseError.internalError
        }
        let payload = InputStream(data: Data(rest))
        payload.open()
        return (header, payload)
    }

    /// Splits a header line into its first space-separated token and the remaining tokens,
    /// after trimming newlines. Returns `nil` if `line` is not valid UTF-8.
    private static func splitArgs<Bytes>(line: Bytes) -> (prefix: String, args: [String])?
    where Bytes: Sequence, Bytes.Element == UInt8 {
        guard let decoded = String(bytes: line, encoding: .utf8) else {
            return nil
        }
        let parts = decoded.trimmingCharacters(in: ["\n"]).components(separatedBy: " ")
        return (parts[0], Array(parts[1...]))
    }

    /// Returns whether `s` is non-empty and consists only of printable ASCII (bytes 33...126).
    private static func isValidString(_ s: String) -> Bool {
        !s.isEmpty && s.utf8.allSatisfy { (33...126).contains($0) }
    }
}
250 |
--------------------------------------------------------------------------------
/Sources/AgeKit/Parse.swift:
--------------------------------------------------------------------------------
1 | import Foundation
2 | import NIOCore
3 |
extension Age {
    enum ParseError: Error {
        case malformedInput
        case parseErrorAtLine(Int)
        case noSecretKeysFound
    }

    /// Parses identities from `input`, one encoded private key per line.
    ///
    /// Empty lines and lines beginning with "#" are skipped. Each remaining line must be an
    /// X25519 identity.
    ///
    /// - Throws: `ParseError.parseErrorAtLine` with the 1-based line number of the first
    ///   unparsable line, `ParseError.noSecretKeysFound` if no identities were found, or any
    ///   error thrown by `X25519Identity`.
    public static func parseIdentities(input: InputStream) throws -> [Identity] {
        var buffer = ByteBuffer(input)
        var identities = [Identity]()
        var lineNumber = 0

        while buffer.readableBytes > 0 {
            lineNumber += 1
            guard let rawLine = buffer.readString(until: "\n") else {
                throw ParseError.malformedInput
            }
            let line = rawLine.trimmingCharacters(in: .newlines)
            guard !line.isEmpty, !line.hasPrefix("#") else {
                continue
            }
            guard let identity = try X25519Identity(line) else {
                throw ParseError.parseErrorAtLine(lineNumber)
            }
            identities.append(identity)
        }

        if identities.isEmpty {
            throw ParseError.noSecretKeysFound
        }
        return identities
    }
}
38 |
--------------------------------------------------------------------------------
/Sources/AgeKit/Primitives.swift:
--------------------------------------------------------------------------------
1 | import CryptoKit
2 | import Foundation
3 |
4 | // MARK: - AEAD Encrypt
5 |
extension Age {
    /// Encrypts `plaintext` with ChaCha20-Poly1305 under a key that is used only once.
    ///
    /// - Returns: Ciphertext followed by the tag, without the nonce, re-based so that its
    ///   indices start at 0.
    static func aeadEncrypt(key: SymmetricKey, plaintext: SymmetricKey) throws -> Data {
        let message = plaintext.withUnsafeBytes { Data($0) }
        // The nonce is fixed because this function is only used in places where the
        // spec guarantees each key is only used once (by deriving it from values
        // that include fresh randomness), allowing us to save the overhead.
        // For the code that encrypts the actual payload, look at the `Stream` types.
        let nonce = try ChaChaPoly.Nonce(data: Data(repeating: 0, count: ChaChaPoly.nonceSize))
        let sealed = try ChaChaPoly.seal(message, using: key, nonce: nonce)
        // CryptoKit's combined property contains all of nonce+ciphertext+tag.
        // For compatibility with other languages we need to exclude the nonce. Copy into a fresh
        // Data: a `dropFirst` slice keeps startIndex == nonceSize, so `[0]` would trap.
        return Data(sealed.combined.dropFirst(ChaChaPoly.nonceSize))
    }

    /// Convenience overload that takes the key as raw bytes.
    static func aeadEncrypt(key: [UInt8], plaintext: SymmetricKey) throws -> Data {
        try aeadEncrypt(key: SymmetricKey(data: key), plaintext: plaintext)
    }
}
25 |
26 | // MARK: - AEAD Decrypt
27 |
extension Age {
    enum AEADError: Error {
        case incorrectCiphertextSize
    }

    /// Decrypt a message of an expected fixed size.
    ///
    /// The message size is limited to mitigate multi-key attacks, where a ciphertext
    /// can be crafted that decrypts successfully under multiple keys. Short ciphertexts
    /// can only target two keys, which has limited impact.
    ///
    /// - Parameters:
    ///   - size: Expected plaintext length in bytes.
    ///   - ciphertext: Ciphertext plus tag, without the nonce (the form `aeadEncrypt` returns).
    /// - Throws: `AEADError.incorrectCiphertextSize` if `ciphertext` is not exactly
    ///   `size + ChaChaPoly.overhead` bytes, or CryptoKit's error if authentication fails.
    static func aeadDecrypt(key: SymmetricKey, size: Int, ciphertext: Data) throws -> Data {
        // Validate the length before handing anything to CryptoKit. Too-short input then
        // reliably reports incorrectCiphertextSize instead of a SealedBox construction error.
        guard ciphertext.count == size + ChaChaPoly.overhead else {
            throw AEADError.incorrectCiphertextSize
        }
        let nonce = try ChaChaPoly.Nonce(data: Data(repeating: 0, count: ChaChaPoly.nonceSize))
        let box = try ChaChaPoly.SealedBox(combined: nonce + ciphertext)
        return try ChaChaPoly.open(box, using: key)
    }

    /// Convenience overload that takes the key as raw bytes.
    static func aeadDecrypt(key: [UInt8], size: Int, ciphertext: Data) throws -> Data {
        let aead = SymmetricKey(data: key)
        return try aeadDecrypt(key: aead, size: size, ciphertext: ciphertext)
    }
}
52 |
53 | // MARK: -
54 |
extension Age {
    /// Computes the MAC of `hdr`.
    ///
    /// The MAC key is derived from `fileKey` with HKDF using the info string "header".
    /// The MAC covers the header's encoding without its own MAC line.
    static func headerMAC(fileKey: SymmetricKey, hdr: Format.Header) throws -> Data {
        let macKey = HKDF.deriveKey(
            inputKeyMaterial: fileKey,
            info: Data("header".utf8),
            outputByteCount: SHA256.byteCount
        )
        var mac = HMAC(key: macKey)
        hdr.encodeWithoutMAC(to: &mac)
        let tag = mac.finalize()
        return Data(tag)
    }

    /// Derives the payload key from `fileKey`.
    ///
    /// The bytes of `nonce` are used as the HKDF salt, with the info string "payload".
    static func streamKey(fileKey: SymmetricKey, nonce: ContiguousBytes) -> SymmetricKey {
        let salt = nonce.withUnsafeBytes { Data($0) }
        return HKDF.deriveKey(
            inputKeyMaterial: fileKey,
            salt: salt,
            info: Data("payload".utf8),
            outputByteCount: SHA256.byteCount
        )
    }
}
77 |
--------------------------------------------------------------------------------
/Sources/AgeKit/Scrypt.swift:
--------------------------------------------------------------------------------
1 | import CryptoKit
2 | import ExtrasBase64
3 | import Foundation
4 | import Scrypt
5 |
/// Label that is prefixed to the random salt before running scrypt.
let scryptLabel = Data("age-encryption.org/v1/scrypt".utf8)
/// Length in bytes of the random scrypt salt.
let scryptSaltSize = 16
8 |
9 | // MARK: - Recipient
10 |
11 | extension Age {
12 | /// A password-based recipient. Anyone with the password can decrypt the message.
13 | ///
14 | /// If an `ScryptRecipient` is used, it must be the only recipient for the file: it can't be mixed
15 | /// with other recipient types and can't be used multiple times for the same file.
16 | ///
17 | /// Its use is not recommended for automated systems, which should prefer `X25519Recipient`.
18 | public struct ScryptRecipient: Recipient {
let password: Data
var workFactor: Int

/// Create a new `ScryptRecipient` with the provided password.
///
/// - Parameter password: The passphrase, stored as its UTF-8 bytes.
/// - Returns: `nil` when `password` is empty. Otherwise the recipient starts with a
///   work factor of 18.
public init?(password: String) {
    guard !password.isEmpty else {
        return nil
    }
    self.password = Data(password.utf8)
    workFactor = 18
}
30 |
/// Error thrown by `wrap` when generating the random salt fails.
enum WrapError: Swift.Error {
    case errSecSuccess(Int32)
}
34 |
/// Sets the scrypt work factor to 2^logN. It must be called before `wrap`.
///
/// The work factor is written into the stanza, so anyone decrypting the file must
/// perform the same amount of work. The identity side rejects stanzas whose factor
/// exceeds its configured maximum. If `setWorkFactor` is not called, a fairly high
/// default (18) is used.
///
/// - Parameter logN: The base-2 logarithm of the scrypt cost. It must satisfy
///   `1 < logN < 30`; other values trap in every build configuration.
public mutating func setWorkFactor(_ logN: Int) {
    // `precondition` rather than `assert`: `assert` is compiled out of optimized
    // builds, which would let an out-of-range factor silently reach `wrap`.
    precondition(logN > 1 && logN < 30, "setWorkFactor called with illegal value")
    workFactor = logN
}
44 |
45 | public func wrap(fileKey: SymmetricKey) throws -> [Age.Stanza] {
46 | var saltBytes = [UInt8](repeating: 0, count: scryptSaltSize)
47 | let status = SecRandomCopyBytes(kSecRandomDefault, scryptSaltSize, &saltBytes)
48 | guard status == errSecSuccess else {
49 | throw WrapError.errSecSuccess(errSecSuccess)
50 | }
51 |
52 | let saltEncoded = Base64.encodeString(bytes: saltBytes, options: .omitPaddingCharacter)
53 | let args = [saltEncoded, String(workFactor)]
54 |
55 | var salt = scryptLabel
56 | salt.append(saltBytes, count: saltBytes.count)
57 |
58 | let k = try scrypt(
59 | password: password.bytes,
60 | salt: salt.bytes,
61 | length: ChaChaPoly.keySize,
62 | N: 1< 1 && logN < 30, "setMaxWorkFactor called with illegal value")
102 | maxWorkFactor = logN
103 | }
104 |
// TODO: Update this to use the new Regex type in Swift 5.7
/// Matches a whole string of decimal digits with no sign and no leading zero
/// (e.g. "18" matches, "018" and "" do not).
private let digitsRe: NSRegularExpression = try! NSRegularExpression(pattern: "^[1-9][0-9]*$")
107 |
108 | public func unwrap(stanzas: [Age.Stanza]) throws -> SymmetricKey {
109 | return try multiUnwrap(stanzas: stanzas) { block in
110 | guard block.type == "scrypt" else {
111 | throw DecryptError.incorrectIdentity
112 | }
113 | guard block.args.count == 2 else {
114 | throw Error.invalidRecipient
115 | }
116 |
117 | var salt = try Format.decodeString(block.args[0])
118 | guard salt.count == scryptSaltSize else {
119 | throw Error.invalidRecipient
120 | }
121 |
122 | let range = NSRange(location: 0, length: block.args[1].count)
123 | guard digitsRe.firstMatch(in: block.args[1], range: range) != nil else {
124 | throw Error.invalidScryptWorkFactor
125 | }
126 | guard let logN = Int(block.args[1]), logN > 0 else {
127 | throw Error.invalidScryptWorkFactor
128 | }
129 | guard logN <= maxWorkFactor else {
130 | throw Error.workFactorTooLarge
131 | }
132 |
133 | salt = Data(scryptLabel.bytes + salt.bytes)
134 | let k = try scrypt(
135 | password: Array(password),
136 | salt: salt.bytes,
137 | length: ChaChaPoly.keySize,
138 | N: 1< Bool {
27 | return self.nonce.last! == lastChunkFlag
28 | }
29 |
/// Returns `true` when every byte of the nonce is zero.
public func isZero() -> Bool {
    !nonce.contains { $0 != 0 }
}
33 | }
34 |
/// Errors raised while processing the chunked payload stream.
enum StreamError: Error {
    case unexpectedEOF, lastChunkEmpty, trailingData, unexpectedEmptyChunk, decryptFailure
}
42 |
43 | public struct StreamReader {
private var aead: SymmetricKey
private var src: InputStream

private var encryptedChunk: Data = Data(capacity: encChunkSize)
private var chunk: Data? = nil
private var nonce: Nonce = Nonce()

/// Creates a reader that decrypts chunks read from `src`.
///
/// - Parameters:
///   - fileKey: Stored as-is and used as the AEAD key for each chunk.
///   - src: The stream the encrypted chunks are read from.
init(fileKey: SymmetricKey, src: InputStream) {
    aead = fileKey
    self.src = src
}
55 |
/// Reads decrypted payload bytes into `buf`.
///
/// `buf.count` is the maximum number of bytes requested. On return, `buf` is replaced by
/// a zero-based `Data` holding the bytes actually read, and that count is returned.
/// A return value of 0 means either that `buf` was empty or that the final chunk has
/// been fully consumed (end of stream).
///
/// - Throws: `StreamError.unexpectedEmptyChunk` if no chunk is available after reading,
///   `StreamError.trailingData` if bytes follow the final chunk, and any error raised
///   by `readChunk()`.
public mutating func read(_ buf: inout Data) throws -> Int {
    // Serve plaintext left over from a previously decrypted chunk first.
    if let chunk = self.chunk, !chunk.isEmpty {
        // Even though prefix returns Data, re-wrap it to make sure the start index is 0
        buf = Data(chunk.prefix(buf.count))
        self.chunk = chunk.dropFirst(buf.count)
        return buf.count
    }
    if buf.isEmpty {
        return 0
    }
    // The final chunk has already been decrypted and drained, so this is a clean end of
    // stream. Reading another chunk would hit EOF on `src` and be misreported as
    // `unexpectedEOF`.
    if self.nonce.isLast() {
        buf = Data()
        return 0
    }

    try readChunk()
    guard let chunk = self.chunk else {
        throw StreamError.unexpectedEmptyChunk
    }
    buf = Data(chunk.prefix(buf.count))
    // `dropFirst` is relative to `startIndex`, so it stays correct even if `chunk` is a
    // slice that does not start at index 0 (unlike subscripting with `buf.count...`).
    self.chunk = chunk.dropFirst(buf.count)
    if self.nonce.isLast() {
        // Probe for one more byte to reject data appended after the final chunk.
        // `read(_:maxLength:)` writes into the buffer, so it must really contain a byte;
        // an empty array with reserved capacity has no initialized element to write to.
        var probe = [UInt8](repeating: 0, count: 1)
        if self.src.read(&probe, maxLength: probe.count) > 0 {
            throw StreamError.trailingData
        }
    }
    return buf.count
}
82 |
83 | private mutating func readChunk() throws {
84 | var buf = [UInt8](repeating: 0, count: encChunkSize)
85 | let n = self.src.read(&buf, maxLength: encChunkSize)
86 | if n == 0 {
87 | throw StreamError.unexpectedEOF
88 | }
89 | self.encryptedChunk = Data(buf[.. Int {
125 | guard buf.count > 0 else {
126 | return 0
127 | }
128 |
129 | var bytesWritten = 0
130 | while !buf.isEmpty {
131 | let toWrite = min(chunkSize - self.chunk.count, buf.count)
132 | self.chunk.append(buf[.. Int {
149 | var d = str.data(using: .utf8)!
150 | return try write(&d)
151 | }
152 |
/// Encrypts and writes whatever is buffered as the final chunk.
///
/// The underlying `OutputStream` is left open; closing it is the caller's job.
public mutating func close() throws {
    try flushChunk(last: true)
}
157 |
/// Seals the buffered chunk and writes it to `dst`, then advances the nonce.
///
/// - Parameter last: When `true`, the chunk may be partial and the nonce is marked with
///   the last-chunk flag; otherwise the buffer must hold exactly `chunkSize` bytes.
private mutating func flushChunk(last: Bool) throws {
    if last {
        nonce.setLastChunkFlag()
    } else {
        assert(chunk.count == chunkSize, "stream: internal error: flush called with partial chunk")
    }

    let chunkNonce = try ChaChaPoly.Nonce(data: nonce.nonce)
    let sealed = try ChaChaPoly.seal(chunk, using: aead, nonce: chunkNonce)
    // CryptoKit's `combined` is nonce + ciphertext + tag, whereas other age
    // implementations emit only ciphertext + tag, so the nonce is dropped.
    let ciphertextAndTag = sealed.combined.dropFirst(Nonce.size)
    _ = try dst.write(ciphertextAndTag)
    nonce.increment()
}
175 | }
176 |
177 |
--------------------------------------------------------------------------------
/Sources/AgeKit/X25519.swift:
--------------------------------------------------------------------------------
1 | import Bech32
2 | import CryptoKit
3 | import ExtrasBase64
4 | import Foundation
5 |
/// HKDF info string used when deriving X25519 stanza wrapping keys.
let x25519Label: String = "age-encryption.org/v1/X25519"
7 |
8 | // MARK: - Recipient
9 |
extension Age {
    /// The standard age public key.
    ///
    /// Messages encrypted to this recipient are decrypted with the matching
    /// `X25519Identity`. The recipient is anonymous: an attacker cannot tell from the
    /// message alone whether it was encrypted to a given recipient.
    public struct X25519Recipient: Recipient {
        enum Error: Swift.Error {
            case invalidType
            case invalidPublicKey
        }

        private let theirPublicKey: Curve25519.KeyAgreement.PublicKey

        /// The recipient's public key in Bech32 form, with the "age" prefix.
        public var string: String {
            try! Bech32.encode(to: "age", data: theirPublicKey.rawRepresentation)
        }

        /// Create an X25519Recipient from a Bech32-encoded public key with the "age1" prefix.
        ///
        /// - Throws: `Error.invalidType` when the human-readable part is not "age",
        ///   `Error.invalidPublicKey` when the key has the wrong length, and any Bech32
        ///   decoding error.
        public init(_ string: String) throws {
            let decoded = try Bech32.decode(from: string)
            guard decoded.hrp == "age" else {
                throw Error.invalidType
            }
            try self.init(decoded.data)
        }

        fileprivate init(_ publicKey: Data) throws {
            guard publicKey.count == Curve25519.pointSize else {
                throw Error.invalidPublicKey
            }
            theirPublicKey = try Curve25519.KeyAgreement.PublicKey(rawRepresentation: publicKey)
        }

        /// Wraps `fileKey` in a single "X25519" stanza.
        ///
        /// An ephemeral key pair is generated. The wrapping key comes from HKDF over the
        /// shared secret, with salt = ephemeral public key followed by the recipient key.
        public func wrap(fileKey: SymmetricKey) throws -> [Stanza] {
            let ephemeral = Curve25519.KeyAgreement.PrivateKey()
            let ephemeralShare = ephemeral.publicKey.rawRepresentation
            let recipientShare = theirPublicKey.rawRepresentation
            let sharedSecret = try ephemeral.sharedSecretFromKeyAgreement(with: theirPublicKey)

            let wrappingKey = sharedSecret.hkdfDerivedSymmetricKey(
                using: SHA256.self,
                salt: Data(ephemeralShare + recipientShare),
                sharedInfo: Data(x25519Label.utf8),
                outputByteCount: SHA256.byteCount
            )
            let body = try aeadEncrypt(key: wrappingKey, plaintext: fileKey)
            let encodedShare = Base64.encodeString(bytes: ephemeralShare, options: .omitPaddingCharacter)
            return [Stanza(type: "X25519", args: [encodedShare], body: body)]
        }
    }
}
69 |
70 | // MARK: - Identity
71 |
extension Age {
    /// X25519Identity is the standard age private key, which can decrypt messages
    /// encrypted to the corresponding `X25519Recipient`.
    public struct X25519Identity: Identity {
        enum Error: Swift.Error {
            case malformedSecretKey
            case incorrectIdentity
            case invalidX25519RecipientBlock
        }

        private let secretKey: Curve25519.KeyAgreement.PrivateKey

        /// The Bech32 private key encoding of the identity.
        public var string: String {
            try! Bech32.encode(to: "AGE-SECRET-KEY-", data: secretKey.rawRepresentation).uppercased()
        }

        /// The public `X25519Recipient` value corresponding to this identity.
        public var recipient: X25519Recipient {
            try! X25519Recipient(secretKey.publicKey.rawRepresentation)
        }

        private init(secretKey: Curve25519.KeyAgreement.PrivateKey) {
            self.secretKey = secretKey
        }

        /// Create an X25519Identity from a Bech32-encoded private key with the "AGE-SECRET-KEY-1" prefix.
        ///
        /// The initializer never returns `nil`; every failure is reported by throwing.
        ///
        /// - Throws: `Error.malformedSecretKey` when the string is not valid Bech32, does
        ///   not use the "AGE-SECRET-KEY-" prefix, or does not contain a valid X25519 key.
        init?(_ string: String) throws {
            // Collapse every failure into the documented error, instead of leaking
            // Bech32 or CryptoKit error types.
            guard let decoded = try? Bech32.decode(from: string),
                  decoded.hrp == "AGE-SECRET-KEY-",
                  let key = try? Curve25519.KeyAgreement.PrivateKey(rawRepresentation: decoded.data)
            else {
                throw Error.malformedSecretKey
            }
            self.secretKey = key
        }

        /// Randomly generate a new `X25519Identity`.
        public static func generate() -> X25519Identity {
            let secretKey = Curve25519.KeyAgreement.PrivateKey()
            return X25519Identity(secretKey: secretKey)
        }

        /// Recovers the file key from an "X25519" stanza addressed to this identity.
        ///
        /// - Throws: `DecryptError.incorrectIdentity` for stanzas of another type or when
        ///   the wrapped key does not decrypt; `Error.invalidX25519RecipientBlock` for a
        ///   malformed stanza; Base64 or CryptoKit errors for undecodable arguments.
        public func unwrap(stanzas: [Stanza]) throws -> SymmetricKey {
            return try multiUnwrap(stanzas: stanzas) { block in
                guard block.type == "X25519" else {
                    throw DecryptError.incorrectIdentity
                }
                guard block.args.count == 1 else {
                    throw Error.invalidX25519RecipientBlock
                }
                let rawShare = try Base64.decode(string: block.args[0], options: .omitPaddingCharacter)
                // Check the length before constructing the key. The initializer throws its
                // own error for a wrong size, which made a later length check unreachable.
                guard rawShare.count == Curve25519.pointSize else {
                    throw Error.invalidX25519RecipientBlock
                }
                let ephemeralShare = try Curve25519.KeyAgreement.PublicKey(rawRepresentation: Data(rawShare))

                let sharedSecret = try secretKey.sharedSecretFromKeyAgreement(with: ephemeralShare)

                // Salt = the sender's ephemeral public key followed by our own public key.
                // These are two different keys, in the same order `X25519Recipient.wrap`
                // concatenates them.
                let salt = Data(ephemeralShare.rawRepresentation + secretKey.publicKey.rawRepresentation)
                let wrappingKey = sharedSecret.hkdfDerivedSymmetricKey(
                    using: SHA256.self,
                    salt: salt,
                    sharedInfo: Data(x25519Label.utf8),
                    outputByteCount: SHA256.byteCount)

                do {
                    let fileKey = try aeadDecrypt(key: wrappingKey, size: fileKeySize, ciphertext: block.body)
                    return SymmetricKey(data: fileKey)
                } catch {
                    throw DecryptError.incorrectIdentity
                }
            }
        }
    }
}
150 |
--------------------------------------------------------------------------------
/Sources/Bech32/Bech32.swift:
--------------------------------------------------------------------------------
1 | import Foundation
2 |
/// The Bech32 alphabet: a 5-bit value is encoded as the character at that index.
private let charset = Data("qpzry9x8gf2tvdw0s3jn54khce6mua7l".utf8)
/// Constants XOR-ed into the checksum by `polymod`.
private let generator: [UInt32] = [0x3b6a57b2, 0x26508e6d, 0x1ea119fa, 0x3d4233dd, 0x2a1462b3]
5 |
/// Errors describing input that cannot be Bech32-encoded.
enum EncodeError: Error {
    case invalidHRP
    case invalidCharacter
    case mixedCase
}

/// Errors thrown by `decode`.
enum DecodeError: Error {
    case mixedCase
    case invalidPosition
    case invalidCharacter
    case invalidChecksum
}

/// Errors thrown by `convertBits`.
enum ConvertError: Error {
    case invalidDataRange
    case illegalZeroPadding
    case nonZeroPadding
}
17 |
/// Computes the Bech32 checksum polynomial over `values`.
///
/// For each input value, the accumulator is shifted left by 5 bits and the value is
/// XOR-ed in. Each `generator` constant is folded in when the matching bit of the
/// previous top 5 bits was set.
private func polymod(_ values: Data) -> UInt32 {
    var checksum: UInt32 = 1
    for value in values {
        let top = checksum >> 25
        checksum = ((checksum & 0x1ff_ffff) << 5) ^ UInt32(value)
        for (bit, constant) in generator.enumerated() where (top >> bit) & 1 == 1 {
            checksum ^= constant
        }
    }
    return checksum
}
30 |
/// Expands a human-readable part for checksum computation.
///
/// The result is the high 3 bits of each lowercased UTF-8 byte, then a zero separator,
/// then the low 5 bits of each byte.
private func expandHrp(for hrp: String) -> Data {
    let bytes = Array(hrp.lowercased().utf8)
    var expanded = Data(capacity: bytes.count * 2 + 1)
    expanded.append(contentsOf: bytes.map { $0 >> 5 })
    expanded.append(0)
    expanded.append(contentsOf: bytes.map { $0 & 31 })
    return expanded
}
35 |
/// Returns `true` when `data` (the 5-bit groups after the separator, including the
/// six checksum groups) carries a valid checksum for `hrp`.
private func verifyChecksum(hrp: String, data: Data) -> Bool {
    polymod(expandHrp(for: hrp) + data) == 1
}
41 |
/// Computes the six 5-bit checksum groups for `hrp` and the 5-bit groups in `data`.
private func createChecksum(for hrp: String, data: Data) -> Data {
    let padded = expandHrp(for: hrp) + data + Data(count: 6)
    let mod = polymod(padded) ^ 1
    return Data((0..<6).map { index in
        UInt8(truncatingIfNeeded: mod >> (5 * (5 - index))) & 31
    })
}
54 |
55 | private func convertBits(data: Data, fromBits: UInt8, toBits: UInt8, pad: Bool) throws -> Data {
56 | var ret = Data()
57 | var acc = UInt32(0)
58 | var bits = UInt8(0)
59 | let maxv = UInt8(1<>fromBits != 0 {
62 | throw ConvertError.invalidDataRange
63 | }
64 | acc = acc<= toBits {
67 | bits -= toBits
68 | ret.append(UInt8(truncatingIfNeeded: acc>>bits)&maxv)
69 | }
70 | }
71 | if pad, bits > 0 {
72 | ret.append(UInt8(truncatingIfNeeded: acc<<(toBits-bits))&maxv)
73 | } else if bits >= fromBits {
74 | throw ConvertError.illegalZeroPadding
75 | }
76 | return ret
77 | }
78 |
/// Encodes `data` as a Bech32 string with human-readable part `hrp`.
///
/// The output is lowercase when `hrp` is lowercase and uppercase otherwise.
///
/// - Throws: `EncodeError.invalidHRP` if `hrp` is empty,
///   `EncodeError.invalidCharacter` if `hrp` contains a character outside printable
///   ASCII (33...126), `EncodeError.mixedCase` if `hrp` mixes upper- and lowercase
///   letters, and any error from converting `data` into 5-bit groups.
public func encode(to hrp: String, data: Data) throws -> String {
    let values = try convertBits(data: data, fromBits: 8, toBits: 5, pad: true)

    // Validate the human-readable part as BIP 173 requires; these are the cases
    // `EncodeError` exists for.
    guard !hrp.isEmpty else {
        throw EncodeError.invalidHRP
    }
    guard hrp.unicodeScalars.allSatisfy({ (33...126).contains($0.value) }) else {
        throw EncodeError.invalidCharacter
    }
    let isLowercase = hrp.lowercased() == hrp
    guard isLowercase || hrp.uppercased() == hrp else {
        throw EncodeError.mixedCase
    }

    var ret = Data(hrp.utf8)
    ret.append(UInt8(ascii: "1"))
    for i in values {
        ret.append(charset[Int(i)])
    }
    for i in createChecksum(for: hrp, data: values) {
        ret.append(charset[Int(i)])
    }
    // Every byte is printable ASCII at this point, so UTF-8 decoding is lossless.
    let s = String(decoding: ret, as: UTF8.self)
    return isLowercase ? s : s.uppercased()
}
92 |
93 | public func decode(from: String) throws -> (hrp: String, data: Data) {
94 | if from.lowercased() != from && from.uppercased() != from {
95 | throw DecodeError.mixedCase
96 | }
97 | let str = from.data(using: .utf8)!
98 | guard let marker = from.lastIndex(of: "1") else {
99 | throw DecodeError.invalidPosition
100 | }
101 |
102 | let pos = from.distance(from: from.startIndex, to: marker)
103 | if pos < 1 || pos+7 > from.count {
104 | throw DecodeError.invalidPosition
105 | }
106 | let hrp = str[.. 126 {
109 | throw DecodeError.invalidCharacter
110 | }
111 | }
112 |
113 | var data = Data()
114 | let s = from.lowercased().data(using: .utf8)!
115 | for c in s[(pos+1)...] {
116 | guard let i = charset.firstIndex(of: c) else { throw DecodeError.invalidCharacter }
117 | data.append(UInt8(i))
118 | }
119 | if !verifyChecksum(hrp: String(data: hrp, encoding: .utf8)!, data: data) {
120 | throw DecodeError.invalidChecksum
121 | }
122 |
123 | data = try convertBits(data: data[0.. stepSize {
26 | b = stepSize
27 | }
28 | var d = Data(src[n..
65 | (?P<_indent> [\ \t]* % (?! [{%] ) [\ \t]* ) (?! [\ \t] | ''' +
66 | linesClose + r''' ) .*
67 | ( \n (?P=_indent) (?! ''' + linesClose + r''' ) .* ) *
68 | )
69 | | (?P [\ \t]* % [ \t]* ''' + linesClose + r''' )
70 | | [\ \t]* (?P %\{ )
71 | (?: [^}]| \} (?!%) )* \}% # Absorb
72 | )
73 | \n? # absorb one trailing newline
74 |
75 | # Substitutions
76 | | (?P \$\{ )
77 | [^}]* \} # Absorb
78 |
79 | # %% and $$ are literal % and $ respectively
80 | | (?P[$%]) (?P=symbol)
81 |
82 | # Literal text
83 | | (?P ''' + literalText + r'''
84 | (?:
85 | # newline that doesn't precede space+%
86 | (?: \n (?! [\ \t]* %[^%] ) )
87 | ''' + literalText + r'''
88 | )*
89 | \n?
90 | )
91 | ''', re.VERBOSE | re.MULTILINE)
92 |
# Matches the '}%' closing a %{...}% block, plus any trailing spaces/tabs and at
# most one newline.
gyb_block_close = re.compile(pattern=r'\}%[ \t]*\n?')
94 |
95 |
def token_pos_to_index(token_pos, start, line_starts):
    """Convert a tokenize (line, column) pair into an absolute text index.

    token_pos is relative to where tokenizing began. Its line number is 1-based
    and counted from the line containing ``start``. On that first line, its
    column is measured from ``start``. line_starts maps each line of the whole
    text to the index of that line's first character.
    """
    token_line, token_col = token_pos
    first_line = bisect(line_starts, start) - 1
    line = first_line + token_line - 1

    if token_line == 1:
        # Tokenizing began at `start`, not at the beginning of that line.
        token_col += start - line_starts[first_line]

    if line >= len(line_starts):
        # Tokenizer errors sometimes report a line past the last one.
        return line_starts[-1]
    return line_starts[line] + token_col
120 |
121 |
def tokenize_python_to_unmatched_close_curly(source_text, start, line_starts):
    """Find where Python code embedded in a template ends.

    Tokenize source_text as Python from index start, tracking the nesting of
    curly braces. Return the index of the first unmatched '}'. If a
    tokenization error comes first, return its position. Otherwise return
    len(source_text).
    """
    reader = StringIO(source_text)
    reader.seek(start)
    depth = 0

    try:
        for tok in tokenize.generate_tokens(reader.readline):
            if tok.string == '{':
                depth += 1
            elif tok.string == '}':
                depth -= 1
                if depth < 0:
                    return token_pos_to_index(tok.start, start, line_starts)
    except tokenize.TokenError as error:
        _, error_pos = error.args
        return token_pos_to_index(error_pos, start, line_starts)

    return len(source_text)
149 |
150 |
def tokenize_template(template_text):
    r"""Given the text of a template, returns an iterator over
    (tokenType, token, match) tuples.

    **Note**: this is template syntax tokenization, not Python
    tokenization.

    When a non-literal token is matched, a client may call
    iter.send(pos) on the iterator to reset the position in
    template_text at which scanning will resume.

    This function provides a base level of tokenization which is
    then refined by ParseContext.token_generator.

    >>> from pprint import *
    >>> pprint(list((kind, text) for kind, text, _ in tokenize_template(
    ... '%for x in range(10):\n% print x\n%end\njuicebox')))
    [('gybLines', '%for x in range(10):\n% print x'),
    ('gybLinesClose', '%end'),
    ('literal', 'juicebox')]

    >>> pprint(list((kind, text) for kind, text, _ in tokenize_template(
    ... '''Nothing
    ... % if x:
    ... % for i in range(3):
    ... ${i}
    ... % end
    ... % else:
    ... THIS SHOULD NOT APPEAR IN THE OUTPUT
    ... ''')))
    [('literal', 'Nothing\n'),
    ('gybLines', '% if x:\n% for i in range(3):'),
    ('substitutionOpen', '${'),
    ('literal', '\n'),
    ('gybLinesClose', '% end'),
    ('gybLines', '% else:'),
    ('literal', 'THIS SHOULD NOT APPEAR IN THE OUTPUT\n')]

    >>> for kind, text, _ in tokenize_template('''
    ... This is $some$ literal stuff containing a ${substitution}
    ... followed by a %{...} block:
    ... %{
    ... # Python code
    ... }%
    ... and here $${are} some %-lines:
    ... % x = 1
    ... % y = 2
    ... % if z == 3:
    ... % print '${hello}'
    ... % end
    ... % for x in zz:
    ... % print x
    ... % # different indentation
    ... % twice
    ... and some lines that literally start with a %% token
    ... %% first line
    ... %% second line
    ... '''):
    ... print((kind, text.strip().split('\n',1)[0]))
    ('literal', 'This is $some$ literal stuff containing a')
    ('substitutionOpen', '${')
    ('literal', 'followed by a %{...} block:')
    ('gybBlockOpen', '%{')
    ('literal', 'and here ${are} some %-lines:')
    ('gybLines', '% x = 1')
    ('gybLinesClose', '% end')
    ('gybLines', '% for x in zz:')
    ('gybLines', '% # different indentation')
    ('gybLines', '% twice')
    ('literal', 'and some lines that literally start with a % token')
    """
    pos = 0
    end = len(template_text)

    # Consecutive literal/symbol pieces are buffered here and emitted as a
    # single 'literal' token; literal_first_match is the match of the first
    # buffered piece and is yielded alongside the joined text.
    saved_literal = []
    literal_first_match = None

    while pos < end:
        # NOTE(review): assumes tokenize_re matches at every position (its
        # literal alternative presumably guarantees this); a None match would
        # raise AttributeError on m.groupdict() below.
        m = tokenize_re.match(template_text, pos, end)

        # pull out the one matched key (ignoring internal patterns starting
        # with _)
        ((kind, text), ) = (
            (kind, text) for (kind, text) in m.groupdict().items()
            if text is not None and kind[0] != '_')

        if kind in ('literal', 'symbol'):
            if len(saved_literal) == 0:
                literal_first_match = m
            # literals and symbols get batched together
            # (a 'symbol' match of '%%' or '$$' contributes only its single
            # captured character, so it comes out as a literal '%' or '$')
            saved_literal.append(text)
            pos = None
        else:
            # found a non-literal. First yield any literal we've accumulated
            if saved_literal != []:
                yield 'literal', ''.join(saved_literal), literal_first_match
                saved_literal = []

            # Then yield the thing we found. If we get a reply, it's
            # the place to resume tokenizing
            pos = yield kind, text, m

        # If we were not sent a new position by our client, resume
        # tokenizing at the end of this match.
        if pos is None:
            pos = m.end(0)
        else:
            # Client is not yet ready to process next token
            # send() returns the value of the next yield, so this bare yield
            # hands it None; the client's following next() then resumes
            # scanning from the position it sent.
            yield

    # Flush any literal text still buffered at end of input.
    if saved_literal != []:
        yield 'literal', ''.join(saved_literal), literal_first_match
263 |
264 |
def split_gyb_lines(source_lines):
    r"""Return a list of lines at which to split the incoming source

    These positions represent the beginnings of python line groups that
    will require a matching %end construct if they are to be closed.

    >>> src = split_lines('''\
    ... if x:
    ... print x
    ... if y: # trailing comment
    ... print z
    ... if z: # another comment\
    ... ''')
    >>> s = split_gyb_lines(src)
    >>> len(s)
    2
    >>> src[s[0]]
    ' print z\n'
    >>> s[1] - len(src)
    0

    >>> src = split_lines('''\
    ... if x:
    ... if y: print 1
    ... if z:
    ... print 2
    ... pass\
    ... ''')
    >>> s = split_gyb_lines(src)
    >>> len(s)
    1
    >>> src[s[0]]
    ' if y: print 1\n'

    >>> src = split_lines('''\
    ... if x:
    ... if y:
    ... print 1
    ... print 2
    ... ''')
    >>> s = split_gyb_lines(src)
    >>> len(s)
    2
    >>> src[s[0]]
    ' if y:\n'
    >>> src[s[1]]
    ' print 1\n'
    """
    # Text and kind of the previous significant (non-comment) token.
    last_token_text, last_token_kind = None, None
    # Indices (into source_lines) of lines that begin a block opened by a
    # trailing ':' whose matching dedent has not been seen yet.
    unmatched_indents = []

    dedents = 0
    try:
        for token_kind, token_text, token_start, \
                (token_end_line, token_end_col), line_text \
                in tokenize.generate_tokens(lambda i=iter(source_lines):
                                            next(i)):

            if token_kind in (tokenize.COMMENT, tokenize.ENDMARKER):
                continue

            # A newline right after ':' opens a block. token_end_line is the
            # 1-based number of the line holding the ':', which equals the
            # 0-based index of the following line -- where the split goes.
            if token_text == '\n' and last_token_text == ':':
                unmatched_indents.append(token_end_line)

            # The tokenizer appends dedents at EOF; don't consider
            # those as matching indentations. Instead just save them
            # up...
            if last_token_kind == tokenize.DEDENT:
                dedents += 1
            # And count them later, when we see something real.
            if token_kind != tokenize.DEDENT and dedents > 0:
                unmatched_indents = unmatched_indents[:-dedents]
                dedents = 0

            last_token_text, last_token_kind = token_text, token_kind

    except tokenize.TokenError:
        # Let the later compile() call report the error
        return []

    # Source ending in ':' opens a block that begins past the last line.
    if last_token_text == ':':
        unmatched_indents.append(len(source_lines))

    return unmatched_indents
349 |
350 |
def code_starts_with_dedent_keyword(source_lines):
    r"""Return True iff the incoming Python source_lines begin with "else",
    "elif", "except", or "finally".

    Initial comments and whitespace are ignored.

    >>> code_starts_with_dedent_keyword(split_lines('if x in y: pass'))
    False
    >>> code_starts_with_dedent_keyword(split_lines('except ifSomethingElse:'))
    True
    >>> code_starts_with_dedent_keyword(
    ... split_lines('\n# comment\nelse: # yes'))
    True
    """
    # Text of the first significant token (or of the last token seen, if the
    # input holds only comments and whitespace).
    first_text = None
    readline = lambda i=iter(source_lines): next(i)
    for token in tokenize.generate_tokens(readline):
        first_text = token.string
        if token.type != tokenize.COMMENT and first_text.strip() != '':
            break

    return first_text in ('else', 'elif', 'except', 'finally')
373 |
374 |
375 | class ParseContext(object):
376 |
377 | """State carried through a parse of a template"""
378 |
379 | filename = ''
380 | template = ''
381 | line_starts = []
382 | code_start_line = -1
383 | code_text = None
384 | tokens = None # The rest of the tokens
385 | close_lines = False
386 |
387 | def __init__(self, filename, template=None):
388 | self.filename = os.path.abspath(filename)
389 | if sys.platform == 'win32':
390 | self.filename = '/'.join(self.filename.split(os.sep))
391 | if template is None:
392 | with io.open(os.path.normpath(filename), encoding='utf-8') as f:
393 | self.template = f.read()
394 | else:
395 | self.template = template
396 | self.line_starts = get_line_starts(self.template)
397 | self.tokens = self.token_generator(tokenize_template(self.template))
398 | self.next_token()
399 |
400 | def pos_to_line(self, pos):
401 | return bisect(self.line_starts, pos) - 1
402 |
403 | def token_generator(self, base_tokens):
404 | r"""Given an iterator over (kind, text, match) triples (see
405 | tokenize_template above), return a refined iterator over
406 | token_kinds.
407 |
408 | Among other adjustments to the elements found by base_tokens,
409 | this refined iterator tokenizes python code embedded in
410 | template text to help determine its true extent. The
411 | expression "base_tokens.send(pos)" is used to reset the index at
412 | which base_tokens resumes scanning the underlying text.
413 |
414 | >>> ctx = ParseContext('dummy', '''
415 | ... %for x in y:
416 | ... % print x
417 | ... % end
418 | ... literally
419 | ... ''')
420 | >>> while ctx.token_kind:
421 | ... print((ctx.token_kind, ctx.code_text or ctx.token_text))
422 | ... ignored = ctx.next_token()
423 | ('literal', '\n')
424 | ('gybLinesOpen', 'for x in y:\n')
425 | ('gybLines', ' print x\n')
426 | ('gybLinesClose', '% end')
427 | ('literal', 'literally\n')
428 |
429 | >>> ctx = ParseContext('dummy',
430 | ... '''Nothing
431 | ... % if x:
432 | ... % for i in range(3):
433 | ... ${i}
434 | ... % end
435 | ... % else:
436 | ... THIS SHOULD NOT APPEAR IN THE OUTPUT
437 | ... ''')
438 | >>> while ctx.token_kind:
439 | ... print((ctx.token_kind, ctx.code_text or ctx.token_text))
440 | ... ignored = ctx.next_token()
441 | ('literal', 'Nothing\n')
442 | ('gybLinesOpen', 'if x:\n')
443 | ('gybLinesOpen', ' for i in range(3):\n')
444 | ('substitutionOpen', 'i')
445 | ('literal', '\n')
446 | ('gybLinesClose', '% end')
447 | ('gybLinesOpen', 'else:\n')
448 | ('literal', 'THIS SHOULD NOT APPEAR IN THE OUTPUT\n')
449 |
450 | >>> ctx = ParseContext('dummy',
451 | ... '''% for x in [1, 2, 3]:
452 | ... % if x == 1:
453 | ... literal1
454 | ... % elif x > 1: # add output line here to fix bug
455 | ... % if x == 2:
456 | ... literal2
457 | ... % end
458 | ... % end
459 | ... % end
460 | ... ''')
461 | >>> while ctx.token_kind:
462 | ... print((ctx.token_kind, ctx.code_text or ctx.token_text))
463 | ... ignored = ctx.next_token()
464 | ('gybLinesOpen', 'for x in [1, 2, 3]:\n')
465 | ('gybLinesOpen', ' if x == 1:\n')
466 | ('literal', 'literal1\n')
467 | ('gybLinesOpen', 'elif x > 1: # add output line here to fix bug\n')
468 | ('gybLinesOpen', ' if x == 2:\n')
469 | ('literal', 'literal2\n')
470 | ('gybLinesClose', '% end')
471 | ('gybLinesClose', '% end')
472 | ('gybLinesClose', '% end')
473 | """
474 | for self.token_kind, self.token_text, self.token_match in base_tokens:
475 | kind = self.token_kind
476 | self.code_text = None
477 |
478 | # Do we need to close the current lines?
479 | self.close_lines = kind == 'gybLinesClose'
480 |
481 | # %{...}% and ${...} constructs
482 | if kind.endswith('Open'):
483 |
484 | # Tokenize text that follows as Python up to an unmatched '}'
485 | code_start = self.token_match.end(kind)
486 | self.code_start_line = self.pos_to_line(code_start)
487 |
488 | close_pos = tokenize_python_to_unmatched_close_curly(
489 | self.template, code_start, self.line_starts)
490 | self.code_text = self.template[code_start:close_pos]
491 | yield kind
492 |
493 | if (kind == 'gybBlockOpen'):
494 | # Absorb any '}% \n'
495 | m2 = gyb_block_close.match(self.template, close_pos)
496 | if not m2:
497 | raise ValueError("Invalid block closure")
498 | next_pos = m2.end(0)
499 | else:
500 | assert kind == 'substitutionOpen'
501 | # skip past the closing '}'
502 | next_pos = close_pos + 1
503 |
504 | # Resume tokenizing after the end of the code.
505 | base_tokens.send(next_pos)
506 |
507 | elif kind == 'gybLines':
508 |
509 | self.code_start_line = self.pos_to_line(
510 | self.token_match.start('gybLines'))
511 | indentation = self.token_match.group('_indent')
512 |
513 | # Strip off the leading indentation and %-sign
514 | source_lines = re.split(
515 | '^' + re.escape(indentation),
516 | self.token_match.group('gybLines') + '\n',
517 | flags=re.MULTILINE)[1:]
518 |
519 | if code_starts_with_dedent_keyword(source_lines):
520 | self.close_lines = True
521 |
522 | last_split = 0
523 | for line in split_gyb_lines(source_lines):
524 | self.token_kind = 'gybLinesOpen'
525 | self.code_text = ''.join(source_lines[last_split:line])
526 | yield self.token_kind
527 | last_split = line
528 | self.code_start_line += line - last_split
529 | self.close_lines = False
530 |
531 | self.code_text = ''.join(source_lines[last_split:])
532 | if self.code_text:
533 | self.token_kind = 'gybLines'
534 | yield self.token_kind
535 | else:
536 | yield self.token_kind
537 |
def next_token(self):
    """Advance the token stream by one token and return the new token kind.

    One item is pulled from ``self.tokens`` (presumably the generator built
    by ``token_generator``, which updates ``token_kind`` and the related
    fields as it yields).  When the stream is exhausted, ``token_kind`` is
    reset to None and None is returned.
    """
    try:
        next(iter(self.tokens))
    except StopIteration:
        self.token_kind = None
        return None
    return self.token_kind
544 |
545 |
546 | _default_line_directive = \
547 | '// ###sourceLocation(file: "%(file)s", line: %(line)d)'
548 |
549 |
class ExecutionContext(object):

    """Mutable state threaded through the execution of a template AST.

    Holds the bindings that template code runs against, the output
    fragments produced so far, and the (file, line) position the output
    currently corresponds to, which decides when a line directive is needed.
    """

    def __init__(self, line_directive=_default_line_directive,
                 **local_bindings):
        # Template code can reach this object through the name __context__.
        local_bindings['__context__'] = self
        self.local_bindings = local_bindings
        self.line_directive = line_directive
        self.result_text = []
        self.last_file_line = None

    def append_text(self, text, file, line):
        """Append ``text``, which begins at template line ``line`` of ``file``.

        When line directives are enabled and the output position no longer
        matches ``(file, line)``, a directive (reporting ``line + 1``, so
        ``line`` is presumably 0-based) is written first.  A directive can
        only start an output line, so when the output currently ends mid-line
        the text up to and including its first newline is emitted before
        the directive is attempted again.
        """
        while self.line_directive and (file, line) != self.last_file_line:
            if not self.result_text or self.result_text[-1].endswith('\n'):
                directive = self.line_directive + '\n'
                self.result_text.append(
                    directive % {'file': file, 'line': line + 1})
                break
            if '\n' not in text:
                # Nowhere to hang a directive; emit the text unmarked.
                break
            split_at = text.find('\n') + 1
            self.result_text.append(text[:split_at])
            text = text[split_at:]
            line += 1

        self.result_text.append(text)
        self.last_file_line = (file, line + text.count('\n'))
583 |
584 |
class ASTNode(object):

    """Common base class for template AST nodes.

    Subclasses provide construction, ``execute`` and ``__str__``; this base
    only supplies ``format_children`` for rendering ``self.children`` in
    ``__str__`` dumps.
    """

    def __init__(self):
        raise NotImplementedError("ASTNode.__init__ is not implemented.")

    def execute(self, context):
        raise NotImplementedError("ASTNode.execute is not implemented.")

    def __str__(self, indent=''):
        raise NotImplementedError("ASTNode.__str__ is not implemented.")

    def format_children(self, indent):
        """Render ``self.children`` as a bracketed list under ``indent``.

        Returns ' []' when there are no children.  Otherwise the result
        starts with a newline, then ``indent + '['``, then each child's
        rendering four spaces deeper, then ``indent + ']'``.
        """
        if self.children:
            nested = indent + '    '
            rows = [indent + '[']
            rows.extend(child.__str__(nested) for child in self.children)
            rows.append(indent + ']')
            return '\n' + '\n'.join(rows)
        return ' []'
606 |
607 |
class Block(ASTNode):

    """An ordered sequence of AST nodes, executed one after another."""

    children = []

    def __init__(self, context):
        # Build nodes until the tokens run out or the current %-lines
        # construct is closed (a '% end', or a dedenting line such as
        # '% else:', sets context.close_lines).
        nodes = []
        self.children = nodes
        while context.token_kind and not context.close_lines:
            node_class = Literal if context.token_kind == 'literal' else Code
            nodes.append(node_class(context))

    def execute(self, context):
        """Execute every child node in order."""
        for child in self.children:
            child.execute(context)

    def __str__(self, indent=''):
        return ''.join((indent, 'Block:', self.format_children(indent)))
630 |
631 |
class Literal(ASTNode):

    """An AST node that emits a fixed run of template text."""

    def __init__(self, context):
        # Capture the current literal token, then advance past it.
        token_kind = context.token_kind
        self.text = context.token_text
        self.start_line_number = context.pos_to_line(
            context.token_match.start(token_kind))
        self.filename = context.filename
        context.next_token()

    def execute(self, context):
        """Append the literal text at its recorded template position."""
        context.append_text(self.text, self.filename, self.start_line_number)

    def __str__(self, indent=''):
        body = strip_trailing_nl(self.text).split('\n')
        return '\n'.join(indent + row for row in ['Literal:'] + body)
650 |
651 |
class Code(ASTNode):

    """An AST node that is evaluated as Python.

    The node holds one compiled code object.  For %-line constructs the
    generated source contains ``__children__[i].execute(__context__)`` calls
    at the points where nested template text belongs, and ``children`` holds
    the Blocks those calls refer to.
    """

    # Compiled code object: 'eval' mode for substitutions, 'exec' otherwise.
    code = None
    # Nested Blocks, indexed by the __children__[i] calls in the source.
    children = ()
    kind = None

    def __init__(self, context):
        """Parse one code construct starting at the context's current token.

        A substitution token becomes a single parenthesized expression
        compiled in 'eval' mode.  Otherwise a run of 'gybLinesOpen' tokens
        (each followed by the Block nested under it and an optional
        'gybLinesClose'), then an optional trailing 'gybLines' or
        'gybBlockOpen' token, is accumulated into one source compiled in
        'exec' mode.
        """
        source = ''
        source_line_count = 0

        def accumulate_code():
            # Pad with newlines up to the construct's starting line so the
            # compiled code's line numbers follow the template's lines.
            s = source + (context.code_start_line - source_line_count) * '\n' \
                + textwrap.dedent(context.code_text)
            line_count = context.code_start_line + \
                context.code_text.count('\n')
            context.next_token()
            return s, line_count

        eval_exec = 'exec'
        if context.token_kind.startswith('substitution'):
            eval_exec = 'eval'
            source, source_line_count = accumulate_code()
            source = '(' + source.strip() + ')'

        else:
            while context.token_kind == 'gybLinesOpen':
                source, source_line_count = accumulate_code()
                # Run the nested Block at this point in the generated code.
                source += ' __children__[%d].execute(__context__)\n' % len(
                    self.children)
                source_line_count += 1

                self.children += (Block(context),)

                if context.token_kind == 'gybLinesClose':
                    context.next_token()

            if context.token_kind == 'gybLines':
                source, source_line_count = accumulate_code()

            # A following %{...}% block is compiled into this same code.
            elif context.token_kind == 'gybBlockOpen':
                source, source_line_count = accumulate_code()

        self.filename = context.filename
        self.start_line_number = context.code_start_line
        self.code = compile(source, context.filename, eval_exec)
        self.source = source

    @staticmethod
    def _result_to_text(result):
        """Convert a non-None expression result to the text to emit."""
        from numbers import Number, Integral
        if isinstance(result, str):
            return result
        if isinstance(result, Number) and not isinstance(result, Integral):
            # Non-integral numbers (floats etc.) are rendered with repr().
            return repr(result)
        # Integers, bools, lists and any other object are rendered with
        # str(); previously non-str objects other than numbers and lists
        # were passed through raw and crashed later in append_text/join.
        return str(result)

    def execute(self, context):
        """Run this node's code against the context's bindings.

        While the code runs, ``__children__`` is bound to this node's
        children and ``__file__`` to its filename.  The previous
        ``__children__`` binding is restored afterwards, even if the code
        raises.  If the code is an expression whose value is not None, that
        value's text is appended to the output at this node's line.

        Raises:
            ValueError: if the code rebinds or deletes ``__children__``.
        """
        bindings = context.local_bindings
        saved_children = bindings.get('__children__')
        bindings['__children__'] = self.children
        bindings['__file__'] = self.filename
        try:
            result = eval(self.code, bindings)
            if bindings.get('__children__') is not self.children:
                raise ValueError(
                    "The code is not allowed to mutate __children__")
        finally:
            bindings['__children__'] = saved_children

        # 'exec'-mode code always yields None; only expression values are
        # emitted.
        if result is not None:
            context.append_text(
                self._result_to_text(result), self.filename,
                self.start_line_number)

    def __str__(self, indent=''):
        """Render the source (and any nested Blocks) for AST dumps."""
        source_lines = re.sub(r'^\n', '', strip_trailing_nl(
            self.source), flags=re.MULTILINE).split('\n')
        if len(source_lines) == 1:
            s = indent + 'Code: {' + source_lines[0] + '}'
        else:
            s = indent + 'Code:\n' + indent + '{\n' + '\n'.join(
                indent + 4 * ' ' + line for line in source_lines
            ) + '\n' + indent + '}'
        return s + self.format_children(indent)
744 |
745 |
def expand(filename, line_directive=_default_line_directive, **local_bindings):
    r"""Parse and execute the template file ``filename``; return its output.

    The file is read as UTF-8 and closed before any template code runs.
    Keyword arguments become variable bindings visible to the template's
    code, and ``line_directive`` is the line-marker format string.  While
    the template executes, the working directory is the template's own
    directory, so it can open files relative to itself.  The previous
    working directory is restored afterwards, even if execution raises.

    >>> from tempfile import NamedTemporaryFile
    >>> # On Windows, the name of a NamedTemporaryFile cannot be used to open
    >>> # the file for a second time if delete=True. Therefore, we have to
    >>> # manually handle closing and deleting this file to allow us to open
    >>> # the file by its name across all platforms.
    >>> f = NamedTemporaryFile(delete=False)
    >>> _ = f.write(
    ... br'''---
    ... % for i in range(int(x)):
    ... a pox on ${i} for epoxy
    ... % end
    ... ${120 +
    ...
    ... 3}
    ... abc
    ... ${"w\nx\nX\ny"}
    ... z
    ... ''')
    >>> f.flush()
    >>> result = expand(
    ... f.name,
    ... line_directive='//#sourceLocation(file: "%(file)s", ' + \
    ... 'line: %(line)d)',
    ... x=2
    ... ).replace(
    ... '"%s"' % f.name.replace('\\', '/'), '"dummy.file"')
    >>> print(result, end='')
    //#sourceLocation(file: "dummy.file", line: 1)
    ---
    //#sourceLocation(file: "dummy.file", line: 3)
    a pox on 0 for epoxy
    //#sourceLocation(file: "dummy.file", line: 3)
    a pox on 1 for epoxy
    //#sourceLocation(file: "dummy.file", line: 5)
    123
    //#sourceLocation(file: "dummy.file", line: 8)
    abc
    w
    x
    X
    y
    //#sourceLocation(file: "dummy.file", line: 10)
    z
    >>> f.close()
    >>> os.remove(f.name)
    """
    with io.open(filename, encoding='utf-8') as f:
        t = parse_template(filename, f.read())
    # The template file handle is released here, before any template code
    # (which may itself read, rewrite or delete files) is executed.
    d = os.getcwd()
    os.chdir(os.path.dirname(os.path.abspath(filename)))
    try:
        return execute_template(
            t, line_directive=line_directive, **local_bindings)
    finally:
        os.chdir(d)
805 |
806 |
def parse_template(filename, text=None):
    r"""Parse a GYB template and return the root Block of its AST.

    ``filename`` names the template.  If ``text`` is supplied, it is taken
    to be the contents of that file, as a string.

    >>> print(parse_template('dummy.file', text=
    ... '''% for x in [1, 2, 3]:
    ... % if x == 1:
    ... literal1
    ... % elif x > 1: # add output line after this line to fix bug
    ... % if x == 2:
    ... literal2
    ... % end
    ... % end
    ... % end
    ... '''))
    Block:
    [
    Code:
    {
    for x in [1, 2, 3]:
    __children__[0].execute(__context__)
    }
    [
    Block:
    [
    Code:
    {
    if x == 1:
    __children__[0].execute(__context__)
    elif x > 1: # add output line after this line to fix bug
    __children__[1].execute(__context__)
    }
    [
    Block:
    [
    Literal:
    literal1
    ]
    Block:
    [
    Code:
    {
    if x == 2:
    __children__[0].execute(__context__)
    }
    [
    Block:
    [
    Literal:
    literal2
    ]
    ]
    ]
    ]
    ]
    ]
    ]

    >>> print(parse_template(
    ... 'dummy.file',
    ... text='%for x in range(10):\n% print(x)\n%end\njuicebox'))
    Block:
    [
    Code:
    {
    for x in range(10):
    __children__[0].execute(__context__)
    }
    [
    Block:
    [
    Code: {print(x)} []
    ]
    ]
    Literal:
    juicebox
    ]

    >>> print(parse_template('/dummy.file', text=
    ... '''Nothing
    ... % if x:
    ... % for i in range(3):
    ... ${i}
    ... % end
    ... % else:
    ... THIS SHOULD NOT APPEAR IN THE OUTPUT
    ... '''))
    Block:
    [
    Literal:
    Nothing
    Code:
    {
    if x:
    __children__[0].execute(__context__)
    else:
    __children__[1].execute(__context__)
    }
    [
    Block:
    [
    Code:
    {
    for i in range(3):
    __children__[0].execute(__context__)
    }
    [
    Block:
    [
    Code: {(i)} []
    Literal:

    ]
    ]
    ]
    Block:
    [
    Literal:
    THIS SHOULD NOT APPEAR IN THE OUTPUT
    ]
    ]
    ]

    >>> print(parse_template('dummy.file', text='''%
    ... %for x in y:
    ... % print(y)
    ... '''))
    Block:
    [
    Code:
    {
    for x in y:
    __children__[0].execute(__context__)
    }
    [
    Block:
    [
    Code: {print(y)} []
    ]
    ]
    ]

    >>> print(parse_template('dummy.file', text='''%
    ... %if x:
    ... % print(y)
    ... AAAA
    ... %else:
    ... BBBB
    ... '''))
    Block:
    [
    Code:
    {
    if x:
    __children__[0].execute(__context__)
    else:
    __children__[1].execute(__context__)
    }
    [
    Block:
    [
    Code: {print(y)} []
    Literal:
    AAAA
    ]
    Block:
    [
    Literal:
    BBBB
    ]
    ]
    ]

    >>> print(parse_template('dummy.file', text='''%
    ... %if x:
    ... % print(y)
    ... AAAA
    ... %# This is a comment
    ... %else:
    ... BBBB
    ... '''))
    Block:
    [
    Code:
    {
    if x:
    __children__[0].execute(__context__)
    # This is a comment
    else:
    __children__[1].execute(__context__)
    }
    [
    Block:
    [
    Code: {print(y)} []
    Literal:
    AAAA
    ]
    Block:
    [
    Literal:
    BBBB
    ]
    ]
    ]

    >>> print(parse_template('dummy.file', text='''\
    ... %for x in y:
    ... AAAA
    ... %if x:
    ... BBBB
    ... %end
    ... CCCC
    ... '''))
    Block:
    [
    Code:
    {
    for x in y:
    __children__[0].execute(__context__)
    }
    [
    Block:
    [
    Literal:
    AAAA
    Code:
    {
    if x:
    __children__[0].execute(__context__)
    }
    [
    Block:
    [
    Literal:
    BBBB
    ]
    ]
    Literal:
    CCCC
    ]
    ]
    ]
    """
    context = ParseContext(filename, text)
    # The whole template is parsed as a single top-level Block.
    return Block(context)
1054 |
1055 |
def execute_template(
        ast, line_directive=_default_line_directive, **local_bindings):
    r"""Execute a parsed template AST and return the generated text.

    Each keyword argument becomes a variable binding visible to the
    template's code.  ``line_directive`` is the line-marker format string
    (given ``%(file)s`` and ``%(line)d``); a false value disables markers.

    >>> root_directory = os.path.abspath('/')
    >>> file_name = (root_directory + 'dummy.file').replace('\\', '/')
    >>> ast = parse_template(file_name, text=
    ... '''Nothing
    ... % if x:
    ... % for i in range(3):
    ... ${i}
    ... % end
    ... % else:
    ... THIS SHOULD NOT APPEAR IN THE OUTPUT
    ... ''')
    >>> out = execute_template(ast,
    ... line_directive='//#sourceLocation(file: "%(file)s", line: %(line)d)',
    ... x=1)
    >>> out = out.replace(file_name, "DUMMY-FILE")
    >>> print(out, end="")
    //#sourceLocation(file: "DUMMY-FILE", line: 1)
    Nothing
    //#sourceLocation(file: "DUMMY-FILE", line: 4)
    0
    //#sourceLocation(file: "DUMMY-FILE", line: 4)
    1
    //#sourceLocation(file: "DUMMY-FILE", line: 4)
    2

    >>> ast = parse_template(file_name, text=
    ... '''Nothing
    ... % a = []
    ... % for x in range(3):
    ... % a.append(x)
    ... % end
    ... ${a}
    ... ''')
    >>> out = execute_template(ast,
    ... line_directive='//#sourceLocation(file: "%(file)s", line: %(line)d)',
    ... x=1)
    >>> out = out.replace(file_name, "DUMMY-FILE")
    >>> print(out, end="")
    //#sourceLocation(file: "DUMMY-FILE", line: 1)
    Nothing
    //#sourceLocation(file: "DUMMY-FILE", line: 6)
    [0, 1, 2]

    >>> ast = parse_template(file_name, text=
    ... '''Nothing
    ... % a = []
    ... % for x in range(3):
    ... % a.append(x)
    ... % end
    ... ${a}
    ... ''')
    >>> out = execute_template(ast,
    ... line_directive='#line %(line)d "%(file)s"', x=1)
    >>> out = out.replace(file_name, "DUMMY-FILE")
    >>> print(out, end="")
    #line 1 "DUMMY-FILE"
    Nothing
    #line 6 "DUMMY-FILE"
    [0, 1, 2]
    """
    context = ExecutionContext(
        line_directive=line_directive, **local_bindings)
    ast.execute(context)
    # Output is accumulated as a list of fragments; join it once at the end.
    return ''.join(context.result_text)
1126 |
1127 |
def _build_argument_parser():
    """Return the argparse parser for the gyb command line."""
    import argparse

    parser = argparse.ArgumentParser(
        formatter_class=argparse.RawDescriptionHelpFormatter,
        description='Generate Your Boilerplate!', epilog='''
A GYB template consists of the following elements:

- Literal text which is inserted directly into the output

- %% or $$ in literal text, which insert literal '%' and '$'
symbols respectively.

- Substitutions of the form ${...}. The Python
expression is converted to a string and the result is inserted
into the output.

- Python code delimited by %{...}%. Typically used to inject
definitions (functions, classes, variable bindings) into the
evaluation context of the template. Common indentation is
stripped, so you can add as much indentation to the beginning
of this code as you like

- Lines beginning with optional whitespace followed by a single
'%' and Python code. %-lines allow you to nest other
constructs inside them. To close a level of nesting, use the
"%end" construct.

- Lines beginning with optional whitespace and followed by a
single '%' and the token "end", which close open constructs in
%-lines.

Example template:

- Hello -
%{
x = 42
def succ(a):
    return a+1
}%

I can assure you that ${x} < ${succ(x)}

% if int(y) > 7:
% for i in range(3):
y is greater than seven!
% end
% else:
y is less than or equal to seven
% end

- The End. -

When run with "gyb -Dy=9", the output is

- Hello -

I can assure you that 42 < 43

y is greater than seven!
y is greater than seven!
y is greater than seven!

- The End. -
'''
    )
    parser.add_argument(
        '-D', action='append', dest='defines', metavar='NAME=VALUE',
        default=[],
        help='''Bindings to be set in the template's execution context''')

    parser.add_argument(
        'file', type=str,
        help='Path to GYB template file (defaults to stdin)', nargs='?',
        default='-')
    parser.add_argument(
        '-o', dest='target', type=str,
        help='Output file (defaults to stdout)', default='-')
    parser.add_argument(
        '--test', action='store_true',
        default=False, help='Run a self-test')
    parser.add_argument(
        '--verbose-test', action='store_true',
        default=False, help='Run a verbose self-test')
    parser.add_argument(
        '--dump', action='store_true',
        default=False, help='Dump the parsed template to stdout')
    parser.add_argument(
        '--line-directive',
        default=_default_line_directive,
        help='''
Line directive format string, which will be
provided 2 substitutions, `%%(line)d` and `%%(file)s`.

Example: `#sourceLocation(file: "%%(file)s", line: %%(line)d)`

The default works automatically with the `line-directive` tool,
which see for more information.
''')
    return parser


def _parse_defines(parser, defines):
    """Turn ``-D NAME=VALUE`` strings into a bindings dict.

    Only the first '=' separates name from value; later duplicates win.
    A define without any '=' is reported through ``parser.error`` (which
    prints usage and exits) instead of failing with an opaque ValueError.
    """
    bindings = {}
    for define in defines:
        name, sep, value = define.partition('=')
        if not sep:
            parser.error(
                "-D argument must have the form NAME=VALUE: %r" % define)
        bindings[name] = value
    return bindings


def main():
    """Command-line entry point: expand a GYB template to stdout or a file."""
    import sys

    parser = _build_argument_parser()
    args = parser.parse_args(sys.argv[1:])

    if args.test or args.verbose_test:
        import doctest
        selfmod = sys.modules[__name__]
        if doctest.testmod(selfmod, verbose=args.verbose_test or None).failed:
            sys.exit(1)

    bindings = _parse_defines(parser, args.defines)
    if args.file == '-':
        ast = parse_template('stdin', sys.stdin.read())
    else:
        with io.open(os.path.normpath(args.file), 'r', encoding='utf-8') as f:
            ast = parse_template(args.file, f.read())
    if args.dump:
        print(ast)
    # Allow the template to open files and import .py files relative to its
    # own directory.  NOTE: a relative -o path is therefore also resolved
    # against the template's directory.
    os.chdir(os.path.dirname(os.path.abspath(args.file)))
    sys.path = ['.'] + sys.path

    # Generate the whole output before opening the target, so a template
    # that raises does not leave an empty or truncated output file behind.
    output = execute_template(ast, args.line_directive, **bindings)
    if args.target == '-':
        sys.stdout.write(output)
    else:
        with io.open(args.target, 'w', encoding='utf-8', newline='\n') as f:
            f.write(output)
1255 |
1256 |
if __name__ == '__main__':
    # Run the command-line interface only when executed as a script,
    # not when imported as a module.
    main()
1259 |
--------------------------------------------------------------------------------