364 lines
12 KiB
Swift
364 lines
12 KiB
Swift
import AuthenticationServices
import CryptoKit
import Foundation
import SwiftUI
|
|
|
|
/// Namespace for the OAuth 2.0 wire types and JSON coding helpers used in this file.
enum OAuth2 {
    /// Errors raised while building or completing an authorization flow.
    enum Error: Swift.Error {
        /// No more specific error was available.
        case unknown
        /// The authorization URL could not be assembled from the endpoint and parameters.
        case invalidAuthorizationURL
        /// The callback URL could not be parsed or did not carry an acceptable response.
        case invalidCallbackURL
        /// The redirect URI cannot be used as an authentication-session callback.
        case invalidRedirectURI
        /// Requested scopes that are not advertised; the associated value lists them.
        case invalidScopes(Set<String>)
    }

    /// The `token_type` of an access token response.
    ///
    /// Coded as a bare JSON string. Parsing ignores case (RFC 6749 §5.1 declares the
    /// value case-insensitive), so `"Bearer"` maps to `.bearer`; unrecognised values
    /// keep their original spelling in `.unknown`.
    enum TokenType: Codable, LosslessStringConvertible {
        case bearer
        case unknown(String)

        init?(_ description: String) {
            self = switch description.lowercased() {
            case "bearer": .bearer
            default: .unknown(description)
            }
        }

        var description: String {
            switch self {
            case .bearer: "bearer"
            case .unknown(let type): type
            }
        }

        init(from decoder: Decoder) throws {
            self = try OAuth2.decodeString(from: decoder)
        }

        func encode(to encoder: Encoder) throws {
            try OAuth2.encodeString(self, to: encoder)
        }
    }

    /// The `grant_type` parameter of a token request, coded as a bare JSON string.
    enum GrantType: Codable, LosslessStringConvertible {
        case authorizationCode
        case unknown(String)

        init?(_ description: String) {
            self = switch description {
            case "authorization_code": .authorizationCode
            default: .unknown(description)
            }
        }

        var description: String {
            switch self {
            case .authorizationCode: "authorization_code"
            case .unknown(let type): type
            }
        }

        init(from decoder: Decoder) throws {
            self = try OAuth2.decodeString(from: decoder)
        }

        func encode(to encoder: Encoder) throws {
            try OAuth2.encodeString(self, to: encoder)
        }
    }

    /// The `response_type` parameter of an authorization request, coded as a bare JSON string.
    enum ResponseType: Codable, LosslessStringConvertible {
        case code
        case idToken
        case unknown(String)

        init?(_ description: String) {
            self = switch description {
            case "code": .code
            case "id_token": .idToken
            default: .unknown(description)
            }
        }

        var description: String {
            switch self {
            case .code: "code"
            case .idToken: "id_token"
            case .unknown(let type): type
            }
        }

        init(from decoder: Decoder) throws {
            self = try OAuth2.decodeString(from: decoder)
        }

        func encode(to encoder: Encoder) throws {
            try OAuth2.encodeString(self, to: encoder)
        }
    }

    /// Token-request parameters; `encoder` maps the property names to snake_case keys.
    fileprivate struct AccessTokenRequest: Codable {
        var clientID: String
        var grantType: GrantType
        var code: String?
        var redirectURI: URL?
    }

    /// A token endpoint response, decoded with `decoder` (snake_case keys).
    struct AccessTokenResponse: Codable {
        var accessToken: String
        var tokenType: TokenType
        /// The `expires_in` field, if present.
        var expiresIn: Double?
        /// The `refresh_token` field, if present.
        var refreshToken: String?
        /// The `scope` field, if present.
        var scope: String?
        /// The OpenID Connect `id_token` field, if present.
        var idToken: String?
    }

    /// Query parameters of an authorization-code callback.
    fileprivate struct CodeResponse: Codable {
        var code: String
        /// The `state` echoed back in the callback, if any.
        var state: String?
    }

    /// A fresh JSON decoder that maps snake_case keys to camelCase properties.
    fileprivate static var decoder: JSONDecoder {
        let decoder = JSONDecoder()
        decoder.keyDecodingStrategy = .convertFromSnakeCase
        return decoder
    }

    /// A fresh JSON encoder that maps camelCase properties to snake_case keys.
    fileprivate static var encoder: JSONEncoder {
        let encoder = JSONEncoder()
        encoder.keyEncodingStrategy = .convertToSnakeCase
        return encoder
    }

    /// Decodes a single string value and converts it with `Value.init?(_:)`.
    ///
    /// - Throws: `DecodingError` if the value is not a string or the conversion fails.
    fileprivate static func decodeString<Value: LosslessStringConvertible>(from decoder: Decoder) throws -> Value {
        let container = try decoder.singleValueContainer()
        let string = try container.decode(String.self)
        guard let value = Value(string) else {
            throw DecodingError.dataCorruptedError(
                in: container,
                debugDescription: "Cannot convert \"\(string)\" to \(Value.self)"
            )
        }
        return value
    }

    /// Encodes `value` as its `description` string.
    fileprivate static func encodeString(_ value: some LosslessStringConvertible, to encoder: Encoder) throws {
        var container = encoder.singleValueContainer()
        try container.encode(value.description)
    }
}
|
|
|
|
/// Namespace for OpenID Connect discovery and the authorization-code flow.
enum OpenID {
    /// The fields of an OpenID Provider discovery document that `Session` uses.
    struct Configuration: Codable {
        var issuer: URL
        var authorizationEndpoint: URL
        var tokenEndpoint: URL
        var userinfoEndpoint: URL
        var scopesSupported: Set<String>

        /// Fetches `<issuerURL>/.well-known/openid-configuration` and decodes it
        /// with `OAuth2.decoder` (snake_case keys).
        ///
        /// - Throws: Networking errors from `URLSession` or a `DecodingError`.
        static func load(from issuerURL: URL) async throws -> Self {
            let configurationURL = issuerURL
                .appending(component: ".well-known")
                .appending(component: "openid-configuration")
            let (data, _) = try await URLSession.shared.data(from: configurationURL)
            return try OAuth2.decoder.decode(Self.self, from: data)
        }
    }

    /// An authorization-code flow (with `state` and PKCE) against one provider.
    struct Session {
        var authorizationEndpoint: URL
        var tokenEndpoint: URL
        var redirectURI: URL
        var responseType = OAuth2.ResponseType.code
        var scopes: Set<String>
        var clientID: String

        /// Discovers the provider configuration at `issuerURL`, then validates `scopes` against it.
        ///
        /// - Throws: Errors from `Configuration.load(from:)`, or `OAuth2.Error.invalidScopes`.
        init(issuerURL: URL, redirectURI: URL, scopes: Set<String>, clientID: String) async throws {
            let configuration = try await Configuration.load(from: issuerURL)
            try self.init(
                configuration: configuration,
                redirectURI: redirectURI,
                scopes: scopes,
                clientID: clientID
            )
        }

        /// Creates a session from an already-loaded configuration.
        ///
        /// - Throws: `OAuth2.Error.invalidScopes` listing every requested scope that
        ///   is missing from `configuration.scopesSupported`.
        init(configuration: Configuration, redirectURI: URL, scopes: Set<String>, clientID: String) throws {
            guard scopes.isSubset(of: configuration.scopesSupported) else {
                throw OAuth2.Error.invalidScopes(scopes.subtracting(configuration.scopesSupported))
            }

            self.authorizationEndpoint = configuration.authorizationEndpoint
            self.tokenEndpoint = configuration.tokenEndpoint
            self.redirectURI = redirectURI
            self.scopes = scopes
            self.clientID = clientID
        }

        /// Runs the flow in an `ASWebAuthenticationSession`, then exchanges the code for tokens.
        ///
        /// - Parameter configure: Called with the session before it starts.
        func authorize(
            configure: (ASWebAuthenticationSession) -> Void
        ) async throws -> OAuth2.AccessTokenResponse {
            let attempt = Attempt()
            let url = try authorizationURL(for: attempt)
            let callbackURL = try await ASWebAuthenticationSession.start(
                url: url,
                redirectURI: redirectURI,
                configure: configure
            )
            return try await handle(callbackURL: callbackURL, attempt: attempt)
        }

        /// Runs the flow in a SwiftUI `WebAuthenticationSession`, then exchanges the code for tokens.
        func authorize(_ session: WebAuthenticationSession) async throws -> OAuth2.AccessTokenResponse {
            let attempt = Attempt()
            let url = try authorizationURL(for: attempt)
            let callbackURL = try await session.start(
                url: url,
                redirectURI: redirectURI
            )
            return try await handle(callbackURL: callbackURL, attempt: attempt)
        }

        // MARK: - Private

        /// Secrets generated fresh for every authorization attempt.
        private struct Attempt {
            /// Opaque anti-CSRF value that the callback must echo back (RFC 6749 §10.12).
            let state = Attempt.randomToken()
            /// PKCE code verifier (RFC 7636 §4.1); only ever sent to the token endpoint.
            let codeVerifier = Attempt.randomToken()

            /// The S256 PKCE challenge: base64url(SHA-256(codeVerifier)) (RFC 7636 §4.2).
            var codeChallenge: String {
                Attempt.base64URLEncoded(Data(SHA256.hash(data: Data(codeVerifier.utf8))))
            }

            /// 32 random bytes (system CSPRNG), base64url-encoded to 43 characters.
            private static func randomToken() -> String {
                let bytes = (0..<32).map { _ in UInt8.random(in: .min ... .max) }
                return base64URLEncoded(Data(bytes))
            }

            /// Base64url encoding without padding (RFC 7636 Appendix A).
            private static func base64URLEncoded(_ data: Data) -> String {
                data.base64EncodedString()
                    .replacingOccurrences(of: "+", with: "-")
                    .replacingOccurrences(of: "/", with: "_")
                    .replacingOccurrences(of: "=", with: "")
            }
        }

        /// Builds the authorization request URL for `attempt`.
        ///
        /// - Throws: `OAuth2.Error.invalidAuthorizationURL` if the URL cannot be assembled.
        private func authorizationURL(for attempt: Attempt) throws -> URL {
            guard var components = URLComponents(url: authorizationEndpoint, resolvingAgainstBaseURL: false) else {
                throw OAuth2.Error.invalidAuthorizationURL
            }
            var queryItems: [URLQueryItem] = [
                .init(name: "client_id", value: clientID),
                .init(name: "response_type", value: responseType.description),
                .init(name: "redirect_uri", value: redirectURI.absoluteString),
                .init(name: "state", value: attempt.state),
                .init(name: "code_challenge", value: attempt.codeChallenge),
                .init(name: "code_challenge_method", value: "S256"),
            ]
            if !scopes.isEmpty {
                // RFC 6749 §3.3: scopes are space-delimited. Sorting keeps the URL deterministic.
                queryItems.append(.init(name: "scope", value: scopes.sorted().joined(separator: " ")))
            }
            // RFC 6749 §3.1: any query already on the endpoint must be retained.
            components.queryItems = (components.queryItems ?? []) + queryItems
            guard let url = components.url else { throw OAuth2.Error.invalidAuthorizationURL }
            return url
        }

        /// Parses the callback, checks its `state`, and redeems the authorization code.
        ///
        /// - Throws: `OAuth2.Error.invalidCallbackURL` for unparsable callbacks,
        ///   a `state` mismatch, or a response type other than `.code`.
        private func handle(callbackURL: URL, attempt: Attempt) async throws -> OAuth2.AccessTokenResponse {
            switch responseType {
            case .code:
                guard let components = URLComponents(url: callbackURL, resolvingAgainstBaseURL: false) else {
                    throw OAuth2.Error.invalidCallbackURL
                }
                let response = try components.decode(OAuth2.CodeResponse.self)
                // A callback that does not echo this attempt's state may be forged.
                guard response.state == attempt.state else { throw OAuth2.Error.invalidCallbackURL }
                return try await requestAccessToken(code: response.code, attempt: attempt)
            default:
                throw OAuth2.Error.invalidCallbackURL
            }
        }

        /// Exchanges `code` at the token endpoint (RFC 6749 §4.1.3, RFC 7636 §4.5).
        private func requestAccessToken(code: String, attempt: Attempt) async throws -> OAuth2.AccessTokenResponse {
            let parameters: [(String, String)] = [
                ("grant_type", OAuth2.GrantType.authorizationCode.description),
                ("code", code),
                // Required because it was part of the authorization request.
                ("redirect_uri", redirectURI.absoluteString),
                ("client_id", clientID),
                ("code_verifier", attempt.codeVerifier),
            ]

            var request = URLRequest(url: tokenEndpoint)
            request.httpMethod = "POST"
            request.setValue("application/x-www-form-urlencoded", forHTTPHeaderField: "Content-Type")
            request.setValue("application/json", forHTTPHeaderField: "Accept")
            request.httpBody = Data(Self.formURLEncoded(parameters).utf8)

            let session = URLSession(configuration: .ephemeral)
            // A URLSession that is never invalidated is never released.
            defer { session.finishTasksAndInvalidate() }
            let (data, _) = try await session.data(for: request)
            return try OAuth2.decoder.decode(OAuth2.AccessTokenResponse.self, from: data)
        }

        /// Encodes `parameters` as `application/x-www-form-urlencoded`, percent-encoding
        /// everything except RFC 3986 unreserved characters.
        private static func formURLEncoded(_ parameters: [(String, String)]) -> String {
            let unreserved = CharacterSet(
                charactersIn: "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-._~"
            )
            func escape(_ string: String) -> String {
                string.addingPercentEncoding(withAllowedCharacters: unreserved) ?? string
            }
            return parameters
                .map { name, value in "\(escape(name))=\(escape(value))" }
                .joined(separator: "&")
        }
    }
}
|
|
|
|
extension WebAuthenticationSession {
    /// Loads `url` and suspends until the session reaches a callback that matches `redirectURI`.
    ///
    /// Newer systems match with an `ASWebAuthenticationSession.Callback` built from the
    /// redirect URI. Older systems match on its URL scheme only, with an empty string
    /// when no scheme is produced.
    ///
    /// - Returns: The callback URL that ended the session.
    /// - Throws: Redirect-URI validation errors or the session's own errors.
    func start(url: URL, redirectURI: URL) async throws -> URL {
        guard #available(iOS 17.4, macOS 14.4, tvOS 17.4, watchOS 10.4, *) else {
            let scheme = try ASWebAuthenticationSession.callbackURLScheme(for: redirectURI) ?? ""
            return try await authenticate(using: url, callbackURLScheme: scheme)
        }
        let matcher = try ASWebAuthenticationSession.callback(for: redirectURI)
        return try await authenticate(using: url, callback: matcher, additionalHeaderFields: [:])
    }
}
|
|
|
|
extension ASWebAuthenticationSession {
    /// Creates, configures and starts a session, then suspends until it completes.
    ///
    /// On systems with `ASWebAuthenticationSession.Callback` the session matches on
    /// `callback(for:)`; otherwise it matches on `callbackURLScheme(for:)`.
    ///
    /// - Parameters:
    ///   - url: The authorization URL to load.
    ///   - redirectURI: Determines which callback URL ends the session.
    ///   - configure: Called with the session before `start()`, e.g. to set a
    ///     presentation context provider.
    /// - Returns: The callback URL delivered by the session.
    /// - Throws: Redirect-URI validation errors, the session's own error, or
    ///   `OAuth2.Error.unknown` if it finished with neither a URL nor an error.
    static func start(url: URL, redirectURI: URL, configure: (ASWebAuthenticationSession) -> Void) async throws -> URL {
        // A checked continuation traps on a double resume and logs if the continuation
        // is dropped, rather than leaving either as undefined behaviour.
        try await withCheckedThrowingContinuation { continuation in
            do {
                let session: ASWebAuthenticationSession
                if #available(iOS 17.4, macOS 14.4, tvOS 17.4, watchOS 10.4, *) {
                    session = ASWebAuthenticationSession(
                        url: url,
                        callback: try callback(for: redirectURI),
                        completionHandler: completionHandler(for: continuation)
                    )
                } else {
                    session = ASWebAuthenticationSession(
                        url: url,
                        callbackURLScheme: try callbackURLScheme(for: redirectURI),
                        completionHandler: completionHandler(for: continuation)
                    )
                }
                configure(session)
                session.start()
            } catch {
                continuation.resume(throwing: error)
            }
        }
    }

    /// Adapts the session's `(URL?, Error?)` callback to `continuation`.
    ///
    /// A non-nil URL wins. Otherwise the error is forwarded, falling back to
    /// `OAuth2.Error.unknown`.
    private static func completionHandler(for continuation: CheckedContinuation<URL, Error>) -> CompletionHandler {
        return { url, error in
            if let url {
                continuation.resume(returning: url)
            } else {
                continuation.resume(throwing: error ?? OAuth2.Error.unknown)
            }
        }
    }
}
|
|
|
|
extension ASWebAuthenticationSession {
    /// Builds the callback matcher for `redirectURI`.
    ///
    /// - `https` URIs match on host and path; a host is required.
    /// - Any other scheme except `http` is treated as a custom scheme.
    ///
    /// Scheme comparison ignores case (RFC 3986 §3.1).
    ///
    /// - Throws: `OAuth2.Error.invalidRedirectURI` for `http`, a missing scheme,
    ///   or an `https` URI without a host.
    @available(iOS 17.4, macOS 14.4, tvOS 17.4, watchOS 10.4, *)
    fileprivate static func callback(for redirectURI: URL) throws -> Callback {
        guard let scheme = redirectURI.scheme else { throw OAuth2.Error.invalidRedirectURI }
        switch scheme.lowercased() {
        case "https":
            guard let host = redirectURI.host else { throw OAuth2.Error.invalidRedirectURI }
            return .https(host: host, path: redirectURI.path)
        case "http":
            // Plain-HTTP callbacks are rejected.
            throw OAuth2.Error.invalidRedirectURI
        default:
            return .customScheme(scheme)
        }
    }

    /// Returns the callback URL scheme for `redirectURI`, for systems that only match on scheme.
    ///
    /// - Returns: `"https"` or the URI's custom scheme. This implementation never
    ///   returns `nil`; the optional return type is kept for existing callers.
    /// - Throws: `OAuth2.Error.invalidRedirectURI` for `http` or a missing scheme
    ///   (mirroring `callback(for:)`). On macOS it also throws for an `https` host that
    ///   the associated-domains entitlement does not list.
    fileprivate static func callbackURLScheme(for redirectURI: URL) throws -> String? {
        guard let scheme = redirectURI.scheme else { throw OAuth2.Error.invalidRedirectURI }
        switch scheme.lowercased() {
        case "http":
            throw OAuth2.Error.invalidRedirectURI
        case "https":
            #if os(macOS)
            // Skipped (permissively) if the entitlement cannot be read at all.
            if
                let host = redirectURI.host,
                let domains = try? Task.current.associatedDomains,
                !associatedDomainEntries(domains, cover: host) {
                throw OAuth2.Error.invalidRedirectURI
            }
            #endif
            return "https"
        default:
            return scheme
        }
    }

    #if os(macOS)
    /// Whether any associated-domains entitlement entry names `host`.
    ///
    /// Entries look like `service:domain[?mode=…]` (e.g. `webcredentials:example.com`),
    /// so only the domain part is compared, case-insensitively. A `*.` wildcard entry
    /// is accepted for its base domain and any subdomain.
    private static func associatedDomainEntries(_ entries: [String], cover host: String) -> Bool {
        let host = host.lowercased()
        return entries.contains { entry in
            var domain = Substring(entry)
            if let colon = domain.firstIndex(of: ":") {
                domain = domain[domain.index(after: colon)...]
            }
            if let query = domain.firstIndex(of: "?") {
                domain = domain[..<query]
            }
            let pattern = domain.lowercased()
            if pattern.hasPrefix("*.") {
                let base = String(pattern.dropFirst(2))
                return host == base || host.hasSuffix("." + base)
            }
            return host == pattern
        }
    }
    #endif
}
|
|
|
|
extension URLComponents {
    /// Decodes `type` from this URL's query, treating each query item name as a key.
    ///
    /// The items are collected into a name → value dictionary, round-tripped through
    /// JSON, and decoded with default (unconverted) key strategies.
    ///
    /// - Throws: `DecodingError.valueNotFound` when there is no query component,
    ///   or any error raised while collecting, encoding or decoding the items.
    fileprivate func decode<T: Decodable>(_ type: T.Type) throws -> T {
        guard let items = queryItems else {
            let context = DecodingError.Context(codingPath: [], debugDescription: "Missing query items")
            throw DecodingError.valueNotFound(T.self, context)
        }
        let dictionary = try items.values
        let json = try JSONEncoder().encode(dictionary)
        return try JSONDecoder().decode(T.self, from: json)
    }
}
|
|
|
|
extension Sequence where Element == URLQueryItem {
    /// The query items as a dictionary from item name to (optional) item value.
    ///
    /// - Throws: `DecodingError.dataCorrupted` if any name occurs more than once.
    fileprivate var values: [String: String?] {
        get throws {
            let pairs = self.map { item in (item.name, item.value) }
            return try Dictionary(pairs, uniquingKeysWith: { _, _ in
                throw DecodingError.dataCorrupted(
                    DecodingError.Context(codingPath: [], debugDescription: "Duplicate query items")
                )
            })
        }
    }
}
|
|
|
|
#if os(macOS)
import Security

/// Minimal wrapper around the running process's `SecTask`, used to read entitlements.
///
/// NOTE(review): within this file on macOS this type shadows Swift Concurrency's `Task`.
private struct Task {
    enum Error: Swift.Error {
        case unknown
    }

    /// The `SecTask` for the current process.
    ///
    /// - Throws: `Error.unknown` if `SecTaskCreateFromSelf` returns `nil`.
    static var current: Self {
        get throws {
            guard let task = SecTaskCreateFromSelf(nil) else { throw Error.unknown }
            return Self(task: task)
        }
    }

    var task: SecTask

    /// The entries of the `com.apple.developer.associated-domains` entitlement.
    ///
    /// Returns an empty array when the entitlement is absent or is not an array of strings.
    ///
    /// - Throws: The `CFError` reported by `SecTaskCopyValueForEntitlement`.
    var associatedDomains: [String] {
        get throws {
            var error: Unmanaged<CFError>?
            let value = SecTaskCopyValueForEntitlement(
                task,
                "com.apple.developer.associated-domains" as CFString,
                &error
            )
            if let error = error?.takeRetainedValue() {
                throw error
            }
            // A missing entitlement yields nil without an error. Treat it as
            // "no domains" instead of crashing on a forced cast.
            return value as? [String] ?? []
        }
    }
}
#endif
|