diff --git a/Sources/D1Kit/D1Client.swift b/Sources/D1Kit/D1Client.swift index b6628b2..c75c27f 100644 --- a/Sources/D1Kit/D1Client.swift +++ b/Sources/D1Kit/D1Client.swift @@ -2,7 +2,7 @@ import Foundation import HTTPTypes public struct D1Client: Sendable { - public init(httpClient: HTTPClientProtocol, accountID: String, apiToken: String) { + public init(httpClient: any HTTPClientProtocol, accountID: String, apiToken: String) { precondition(accountID.isASCII) precondition(apiToken.isASCII) self.httpClient = httpClient diff --git a/Sources/D1Kit/D1Database.swift b/Sources/D1Kit/D1Database.swift index 070dd02..137793e 100644 --- a/Sources/D1Kit/D1Database.swift +++ b/Sources/D1Kit/D1Database.swift @@ -11,6 +11,8 @@ public struct D1Database: Sendable { public var client: D1Client public var databaseID: String + public var encodingOptions: D1ParameterEncodingOptions = .init() + public var dateDecodingStrategy: JSONDecoder.DateDecodingStrategy = .secondsSince1970 private func databaseURL() -> URL { return URL(string: "https://api.cloudflare.com/client/v4/accounts/\(client.accountID)/d1/database/\(databaseID)")! @@ -37,7 +39,7 @@ public struct D1Database: Sendable { binds: [any D1ParameterBindable], as rowType: D.Type ) async throws -> [D] { - let params = binds.map { $0.encodeToD1Parameter() } + let params = binds.map { $0.encodeToD1Parameter(options: encodingOptions) } return try await _query(query, params: params, as: rowType) } @@ -45,7 +47,11 @@ public struct D1Database: Sendable { _ query: QueryString, as rowType: D.Type ) async throws -> [D] { - return try await _query(query.query, params: query.params, as: rowType) + return try await _query( + query.query, + params: query.params.map({ $0.encodeToD1Parameter(options: encodingOptions) }), + as: rowType + ) } private struct Empty: Decodable {} @@ -55,12 +61,16 @@ public struct D1Database: Sendable { binds: repeat each B ) async throws { var params: [String] = [] - repeat params.append((each binds).encodeToD1Parameter()) + repeat params.append((each binds).encodeToD1Parameter(options: encodingOptions)) _ = try await _query(query, params: params, as: Empty.self) } public func query(_ query: QueryString) async throws { - _ = try await _query(query.query, params: query.params, as: Empty.self) + _ = try await _query( + query.query, + params: query.params.map({ $0.encodeToD1Parameter(options: encodingOptions) }), + as: Empty.self + ) } private func _query( @@ -83,14 +93,7 @@ public struct D1Database: Sendable { // print(String(data: body, encoding: .utf8) ?? "") let decoder = JSONDecoder() - decoder.dateDecodingStrategy = .custom { decoder in - let c = try decoder.singleValueContainer() - let string = try c.decode(String.self) - guard let date = DateFormatter.sqlite.date(from: string) else { - throw DecodingError.dataCorruptedError(in: c, debugDescription: "\(string) is bad format.") - } - return date - } + decoder.dateDecodingStrategy = dateDecodingStrategy let responseBody = try decoder.decode(QueryResponse.self, from: body) switch response.status { case .ok: diff --git a/Sources/D1Kit/D1ParameterBindable.swift b/Sources/D1Kit/D1ParameterBindable.swift index 8d8324c..147c371 100644 --- a/Sources/D1Kit/D1ParameterBindable.swift +++ b/Sources/D1Kit/D1ParameterBindable.swift @@ -1,23 +1,34 @@ import Foundation public protocol D1ParameterBindable: Sendable { - func encodeToD1Parameter() -> String + func encodeToD1Parameter(options: D1ParameterEncodingOptions) -> String } extension String: D1ParameterBindable { - public func encodeToD1Parameter() -> String { + public func encodeToD1Parameter(options: D1ParameterEncodingOptions) -> String { self } } extension Substring: D1ParameterBindable { - public func encodeToD1Parameter() -> String { + public func encodeToD1Parameter(options: D1ParameterEncodingOptions) -> String { String(self) } } extension Date: D1ParameterBindable { - public func encodeToD1Parameter() -> String { - DateFormatter.sqlite.string(from: self) + public func encodeToD1Parameter(options: D1ParameterEncodingOptions) -> String { + switch options.dateEncodingStrategy { + case .secondsSince1970: + return Int(timeIntervalSince1970).description + case .millisecondsSince1970: + return Int(timeIntervalSince1970 * 1000).description + case .iso8601: + return ISO8601DateFormatter().string(from: self) + case .formatted(let formatter): + return formatter.string(from: self) + case .custom(let custom): + return custom(self, options) + } } } diff --git a/Sources/D1Kit/D1ParameterEncodingOptions.swift b/Sources/D1Kit/D1ParameterEncodingOptions.swift new file mode 100644 index 0000000..f27ef2c --- /dev/null +++ b/Sources/D1Kit/D1ParameterEncodingOptions.swift @@ -0,0 +1,47 @@ +import Foundation + +public protocol D1ParameterEncodingOptionKey { + associatedtype Value: Sendable + static var defaultValue: Self.Value { get } +} + +public struct D1ParameterEncodingOptions: Sendable { + private var storage: [ObjectIdentifier: any Sendable] = [:] + + public init() {} + + public subscript(key: Key.Type) -> Key.Value { + get { + let i = ObjectIdentifier(key) + if let value = storage[i] as? Key.Value { + return value + } + return Key.defaultValue + } + set { + let i = ObjectIdentifier(key) + storage[i] = newValue + } + } +} + +public enum D1DateEncodingStrategy { + case secondsSince1970 + case millisecondsSince1970 + case iso8601 + case formatted(DateFormatter) + @preconcurrency case custom(@Sendable (Date, D1ParameterEncodingOptions) -> String) +} + +public struct D1DateEncodingStrategyKey: D1ParameterEncodingOptionKey { + public static var defaultValue: D1DateEncodingStrategy { + .secondsSince1970 + } +} + +extension D1ParameterEncodingOptions { + public var dateEncodingStrategy: D1DateEncodingStrategy { + get { self[D1DateEncodingStrategyKey.self] } + set { self[D1DateEncodingStrategyKey.self] = newValue } + } +} diff --git a/Sources/D1Kit/QueryString.swift b/Sources/D1Kit/QueryString.swift index f47310a..5aa247f 100644 --- a/Sources/D1Kit/QueryString.swift +++ b/Sources/D1Kit/QueryString.swift @@ -4,7 +4,7 @@ public struct QueryString { @usableFromInline var query: String @usableFromInline - var params: [String] = [] + var params: [any D1ParameterBindable] = [] @inlinable public init(_ string: some StringProtocol) { @@ -55,17 +55,17 @@ extension QueryString: StringInterpolationProtocol { } @inlinable - public mutating func appendInterpolation(bind value: some D1ParameterBindable) { + public mutating func appendInterpolation(bind value: any D1ParameterBindable) { self.query.append("?") - self.params.append(value.encodeToD1Parameter()) + self.params.append(value) } @inlinable - public mutating func appendInterpolation(binds values: [some D1ParameterBindable]) { + public mutating func appendInterpolation(binds values: [any D1ParameterBindable]) { self.query.append("(") self.query.append([String](repeating: "?", count: values.count).joined(separator: ",")) self.query.append(")") - self.params.append(contentsOf: values.map { $0.encodeToD1Parameter() }) + self.params.append(contentsOf: values) } @inlinable diff --git a/Tests/D1KitTests/D1KitTests.swift b/Tests/D1KitTests/D1KitTests.swift index 42303b8..bb7824c 100644 --- a/Tests/D1KitTests/D1KitTests.swift +++ b/Tests/D1KitTests/D1KitTests.swift @@ -28,7 +28,7 @@ final class D1KitTests: XCTestCase { SELECT 1 as "intValue" , 'Hello, world!' as "textValue" - , CURRENT_TIMESTAMP as "dateValue" + , unixepoch(CURRENT_TIMESTAMP) as "dateValue" """, as: Row.self).first if let test { XCTAssertEqual(test.intValue, 1) @@ -51,7 +51,7 @@ final class D1KitTests: XCTestCase { SELECT cast(? as integer) as "intValue" , ? as "textValue" - , ? as "dateValue" + , cast(? as integer) as "dateValue" """, binds: [String(42), "swift", now], as: Row.self).first @@ -81,7 +81,7 @@ final class D1KitTests: XCTestCase { , \(literal: 42) as "intValue" , \(literal: 42.195) as "doubleValue" , \(bind: "swift") as "textValue" - , \(bind: now) as "dateValue" + , cast(\(bind: now) as integer) as "dateValue" FROM cte WHERE @@ -104,4 +104,75 @@ final class D1KitTests: XCTestCase { PRAGMA quick_check(0) """) } + + func testFormatCheck() async throws { + struct Row: Decodable { + var bindedValueType: String + var timestamp: String + var unixepoch: Double + } + let test = try await db.query(""" + SELECT + typeof(\(bind: "swift")) as "bindedValueType" + , CURRENT_TIMESTAMP as timestamp + , unixepoch(CURRENT_TIMESTAMP) as unixepoch + """, as: Row.self).first + + if let test { + XCTAssertEqual(test.bindedValueType, "text") + XCTAssertNotNil(DateFormatter.sqliteTimestamp.date(from: test.timestamp)) + XCTAssertEqual(test.unixepoch.remainder(dividingBy: 1), 0.0, accuracy: 0.0) + } else { + XCTFail() + } + } + + func testDateCodingStrategy() async throws { + struct Row: Decodable { + var now: Date + } + + let now = Date(timeIntervalSince1970: floor(Date().timeIntervalSince1970)) + + var db = db! + + db.encodingOptions.dateEncodingStrategy = .secondsSince1970 + db.dateDecodingStrategy = .secondsSince1970 + var test = try await db.query(""" + SELECT + cast(\(bind: now) as integer) as now + """, as: Row.self).first + + if let test { + XCTAssertEqual(test.now, now) + } else { + XCTFail() + } + + db.encodingOptions.dateEncodingStrategy = .millisecondsSince1970 + db.dateDecodingStrategy = .millisecondsSince1970 + test = try await db.query(""" + SELECT + cast(\(bind: now) as integer) as now + """, as: Row.self).first + + if let test { + XCTAssertEqual(test.now, now) + } else { + XCTFail() + } + + db.encodingOptions.dateEncodingStrategy = .iso8601 + db.dateDecodingStrategy = .iso8601 + test = try await db.query(""" + SELECT + \(bind: now) as now + """, as: Row.self).first + + if let test { + XCTAssertEqual(test.now, now) + } else { + XCTFail() + } + } } diff --git a/Sources/D1Kit/SQLiteDateFormatter.swift b/Tests/D1KitTests/SQLiteDateFormatter.swift similarity index 95% rename from Sources/D1Kit/SQLiteDateFormatter.swift rename to Tests/D1KitTests/SQLiteDateFormatter.swift index 6f9506b..4ba129d 100644 --- a/Sources/D1Kit/SQLiteDateFormatter.swift +++ b/Tests/D1KitTests/SQLiteDateFormatter.swift @@ -1,7 +1,7 @@ import Foundation extension DateFormatter { - @ThreadLocal static var sqlite: DateFormatter = { + @ThreadLocal static var sqliteTimestamp: DateFormatter = { let f = DateFormatter() f.dateFormat = "yyyy-MM-dd HH:mm:ss" f.timeZone = .init(secondsFromGMT: 0)