import AuthenticationServices
import CryptoKit
import Foundation
import SwiftUI

#if os(macOS)
import Security
#endif

/// Namespace for the OAuth 2.0 (RFC 6749) wire types used by the OpenID Connect client below.
enum OAuth2 {
    /// Errors raised by the authorization flow in this file.
    enum Error: Swift.Error {
        /// The web authentication session finished with neither a callback URL nor an error.
        case unknown
        /// The authorization request URL could not be assembled.
        case invalidAuthorizationURL
        /// The callback URL could not be parsed, or its `state` did not match the request.
        case invalidCallbackURL
        /// The redirect URI cannot be used as a web-authentication callback.
        case invalidRedirectURI
        /// The requested scopes that the provider does not advertise.
        case invalidScopes(Set<String>)
    }

    /// The `token_type` of an access token response (RFC 6749 §7.1).
    enum TokenType: Codable, LosslessStringConvertible {
        case bearer
        case unknown(String)

        /// Token type names are case-insensitive (RFC 6749 §5.1), so "Bearer" maps to `.bearer`.
        init?(_ description: String) {
            self = switch description.lowercased() {
            case "bearer": .bearer
            default: .unknown(description)
            }
        }

        var description: String {
            switch self {
            case .bearer: "bearer"
            case .unknown(let type): type
            }
        }

        // The synthesized Codable for an enum with associated values uses a keyed container
        // (e.g. `{"bearer":{}}`), which cannot read the plain JSON string servers send.
        init(from decoder: Decoder) throws {
            let value = try decoder.singleValueContainer().decode(String.self)
            self = Self(value) ?? .unknown(value)
        }

        func encode(to encoder: Encoder) throws {
            var container = encoder.singleValueContainer()
            try container.encode(description)
        }
    }

    /// The `grant_type` parameter of a token request (RFC 6749 §4.1.3).
    enum GrantType: Codable, LosslessStringConvertible {
        case authorizationCode
        case unknown(String)

        init?(_ description: String) {
            self = switch description {
            case "authorization_code": .authorizationCode
            default: .unknown(description)
            }
        }

        var description: String {
            switch self {
            case .authorizationCode: "authorization_code"
            case .unknown(let type): type
            }
        }

        // Encode as a bare string rather than the synthesized keyed container.
        init(from decoder: Decoder) throws {
            let value = try decoder.singleValueContainer().decode(String.self)
            self = Self(value) ?? .unknown(value)
        }

        func encode(to encoder: Encoder) throws {
            var container = encoder.singleValueContainer()
            try container.encode(description)
        }
    }

    /// The `response_type` parameter of an authorization request (`id_token` comes from OpenID Connect).
    enum ResponseType: Codable, LosslessStringConvertible {
        case code
        case idToken
        case unknown(String)

        init?(_ description: String) {
            self = switch description {
            case "code": .code
            case "id_token": .idToken
            default: .unknown(description)
            }
        }

        var description: String {
            switch self {
            case .code: "code"
            case .idToken: "id_token"
            case .unknown(let type): type
            }
        }

        // Encode as a bare string rather than the synthesized keyed container.
        init(from decoder: Decoder) throws {
            let value = try decoder.singleValueContainer().decode(String.self)
            self = Self(value) ?? .unknown(value)
        }

        func encode(to encoder: Encoder) throws {
            var container = encoder.singleValueContainer()
            try container.encode(description)
        }
    }

    /// Parameters of an authorization-code token request (RFC 6749 §4.1.3, RFC 7636 §4.5).
    fileprivate struct AccessTokenRequest: Codable {
        var clientID: String
        var grantType: GrantType
        var code: String?
        var redirectURI: URL?
        var codeVerifier: String?

        /// The request as an `application/x-www-form-urlencoded` body, the format RFC 6749 §4.1.3
        /// requires. Nil fields are omitted.
        var formEncoded: Data {
            let fields: [(String, String?)] = [
                ("client_id", clientID),
                ("grant_type", grantType.description),
                ("code", code),
                ("redirect_uri", redirectURI?.absoluteString),
                ("code_verifier", codeVerifier),
            ]
            var pairs: [String] = []
            for (name, value) in fields {
                guard let value else { continue }
                pairs.append("\(OAuth2.formEscape(name))=\(OAuth2.formEscape(value))")
            }
            return Data(pairs.joined(separator: "&").utf8)
        }
    }

    /// A successful token response (RFC 6749 §5.1).
    struct AccessTokenResponse: Codable {
        var accessToken: String
        var tokenType: TokenType
        var expiresIn: Double?
        var refreshToken: String?
    }

    /// Query parameters of an authorization-code callback (RFC 6749 §4.1.2).
    fileprivate struct CodeResponse: Codable {
        var code: String
        var state: String?
    }

    /// A PKCE verifier/challenge pair using the `S256` method (RFC 7636 §4.1-4.2).
    fileprivate struct ProofKey {
        var verifier: String
        var challenge: String

        init() {
            let verifier = OAuth2.randomURLSafeString()
            self.verifier = verifier
            self.challenge = OAuth2.base64URLEncoded(Data(SHA256.hash(data: Data(verifier.utf8))))
        }
    }

    /// JSON decoder that maps the providers' snake_case keys onto camelCase properties.
    fileprivate static var decoder: JSONDecoder {
        let decoder = JSONDecoder()
        decoder.keyDecodingStrategy = .convertFromSnakeCase
        return decoder
    }

    /// RFC 3986 unreserved characters: the only ones left unescaped in form bodies.
    fileprivate static let formUnreserved = CharacterSet(
        charactersIn: "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-._~"
    )

    /// Percent-encodes everything outside `formUnreserved`.
    ///
    /// This keeps `+`, `&`, `=` and non-ASCII text intact through form decoding.
    fileprivate static func formEscape(_ string: String) -> String {
        string.addingPercentEncoding(withAllowedCharacters: formUnreserved) ?? string
    }

    /// 32 bytes from CryptoKit's key generator, base64url-encoded without padding (43 characters).
    fileprivate static func randomURLSafeString() -> String {
        base64URLEncoded(SymmetricKey(size: .bits256).withUnsafeBytes { Data($0) })
    }

    /// Base64url encoding without padding (RFC 4648 §5), as PKCE requires.
    fileprivate static func base64URLEncoded(_ data: Data) -> String {
        data.base64EncodedString()
            .replacingOccurrences(of: "+", with: "-")
            .replacingOccurrences(of: "/", with: "_")
            .replacingOccurrences(of: "=", with: "")
    }
}

/// OpenID Connect discovery and authorization-code sign-in.
enum OpenID {
    /// The subset of a provider's discovery document that this client uses.
    struct Configuration: Codable {
        var issuer: URL
        var authorizationEndpoint: URL
        var tokenEndpoint: URL
        var userinfoEndpoint: URL
        var scopesSupported: Set<String>

        /// Fetches and decodes `<issuerURL>/.well-known/openid-configuration`.
        static func load(from issuerURL: URL) async throws -> Self {
            let configurationURL = issuerURL
                .appending(component: ".well-known")
                .appending(component: "openid-configuration")
            let (data, _) = try await URLSession.shared.data(from: configurationURL)
            return try OAuth2.decoder.decode(Self.self, from: data)
        }
    }

    /// A configured client for one provider, able to run the authorization-code flow.
    struct Session {
        var authorizationEndpoint: URL
        var tokenEndpoint: URL
        var redirectURI: URL
        var responseType = OAuth2.ResponseType.code
        var scopes: Set<String>
        var clientID: String

        /// Discovers the provider's configuration from `issuerURL`, then configures the session.
        /// - Throws: Network and decoding errors, or `OAuth2.Error.invalidScopes`.
        init(issuerURL: URL, redirectURI: URL, scopes: Set<String>, clientID: String) async throws {
            let configuration = try await Configuration.load(from: issuerURL)
            try self.init(
                configuration: configuration,
                redirectURI: redirectURI,
                scopes: scopes,
                clientID: clientID
            )
        }

        /// Configures the session from an already-loaded discovery document.
        /// - Throws: `OAuth2.Error.invalidScopes` for scopes missing from `scopesSupported`.
        init(configuration: Configuration, redirectURI: URL, scopes: Set<String>, clientID: String) throws {
            // NOTE(review): OIDC Discovery lets providers omit supported scopes from
            // scopes_supported, so this strict check may reject valid requests.
            guard scopes.isSubset(of: configuration.scopesSupported) else {
                throw OAuth2.Error.invalidScopes(scopes.subtracting(configuration.scopesSupported))
            }
            self.authorizationEndpoint = configuration.authorizationEndpoint
            self.tokenEndpoint = configuration.tokenEndpoint
            self.redirectURI = redirectURI
            self.scopes = scopes
            self.clientID = clientID
        }

        /// Builds the authorization request URL (RFC 6749 §4.1.1) with `state` and a PKCE challenge.
        private func authorizationURL(state: String, codeChallenge: String) throws -> URL {
            guard var components = URLComponents(url: authorizationEndpoint, resolvingAgainstBaseURL: false) else {
                throw OAuth2.Error.invalidAuthorizationURL
            }
            var queryItems: [URLQueryItem] = [
                .init(name: "client_id", value: clientID),
                .init(name: "response_type", value: responseType.description),
                .init(name: "redirect_uri", value: redirectURI.absoluteString),
                .init(name: "state", value: state),
                .init(name: "code_challenge", value: codeChallenge),
                .init(name: "code_challenge_method", value: "S256"),
            ]
            if !scopes.isEmpty {
                // `scope` is a space-delimited list (RFC 6749 §3.3).
                queryItems.append(.init(name: "scope", value: scopes.joined(separator: " ")))
            }
            // Any query already on the endpoint must be retained (RFC 6749 §3.1).
            components.queryItems = (components.queryItems ?? []) + queryItems
            guard let authorizationURL = components.url else {
                throw OAuth2.Error.invalidAuthorizationURL
            }
            return authorizationURL
        }

        /// Validates the callback against `state` and redeems the code it carries.
        private func handle(
            callbackURL: URL,
            state: String,
            codeVerifier: String
        ) async throws -> OAuth2.AccessTokenResponse {
            switch responseType {
            case .code:
                guard let components = URLComponents(url: callbackURL, resolvingAgainstBaseURL: false) else {
                    throw OAuth2.Error.invalidCallbackURL
                }
                let response = try components.decode(OAuth2.CodeResponse.self)
                // Reject callbacks that do not echo our state (CSRF protection, RFC 6749 §10.12).
                guard response.state == state else { throw OAuth2.Error.invalidCallbackURL }
                return try await redeem(code: response.code, codeVerifier: codeVerifier)
            default:
                throw OAuth2.Error.invalidCallbackURL
            }
        }

        /// Exchanges an authorization code for tokens at the token endpoint (RFC 6749 §4.1.3).
        private func redeem(code: String, codeVerifier: String) async throws -> OAuth2.AccessTokenResponse {
            let body = OAuth2.AccessTokenRequest(
                clientID: clientID,
                grantType: .authorizationCode,
                code: code,
                // Required because the authorization request included redirect_uri.
                redirectURI: redirectURI,
                codeVerifier: codeVerifier
            )
            var request = URLRequest(url: tokenEndpoint)
            request.httpMethod = "POST"
            request.setValue("application/x-www-form-urlencoded", forHTTPHeaderField: "Content-Type")
            // Some providers reply with a form-encoded body unless JSON is explicitly accepted.
            request.setValue("application/json", forHTTPHeaderField: "Accept")
            request.httpBody = body.formEncoded

            let session = URLSession(configuration: .ephemeral)
            // Invalidate so the per-request session is released once the task finishes.
            defer { session.finishTasksAndInvalidate() }
            let (data, _) = try await session.data(for: request)
            return try OAuth2.decoder.decode(OAuth2.AccessTokenResponse.self, from: data)
        }

        /// Runs the flow in an `ASWebAuthenticationSession`, customized by `configure` before it
        /// starts (e.g. to set its presentation context provider).
        func authorize(
            configure: (ASWebAuthenticationSession) -> Void
        ) async throws -> OAuth2.AccessTokenResponse {
            let state = OAuth2.randomURLSafeString()
            let proofKey = OAuth2.ProofKey()
            let authorizationURL = try authorizationURL(state: state, codeChallenge: proofKey.challenge)
            let callbackURL = try await ASWebAuthenticationSession.start(
                url: authorizationURL,
                redirectURI: redirectURI,
                configure: configure
            )
            return try await handle(callbackURL: callbackURL, state: state, codeVerifier: proofKey.verifier)
        }

        /// Runs the flow using SwiftUI's `WebAuthenticationSession`.
        func authorize(_ session: WebAuthenticationSession) async throws -> OAuth2.AccessTokenResponse {
            let state = OAuth2.randomURLSafeString()
            let proofKey = OAuth2.ProofKey()
            let authorizationURL = try authorizationURL(state: state, codeChallenge: proofKey.challenge)
            let callbackURL = try await session.start(url: authorizationURL, redirectURI: redirectURI)
            return try await handle(callbackURL: callbackURL, state: state, codeVerifier: proofKey.verifier)
        }
    }
}

extension WebAuthenticationSession {
    /// Authenticates at `url` and returns the callback URL matching `redirectURI`.
    func start(url: URL, redirectURI: URL) async throws -> URL {
        if #available(iOS 17.4, macOS 14.4, tvOS 17.4, watchOS 10.4, *) {
            return try await authenticate(
                using: url,
                callback: try ASWebAuthenticationSession.callback(for: redirectURI),
                additionalHeaderFields: [:]
            )
        } else {
            // `callbackURLScheme(for:)` yields nil for http or scheme-less URIs; pass "" then.
            let callbackURLScheme = try ASWebAuthenticationSession.callbackURLScheme(for: redirectURI) ?? ""
            return try await authenticate(using: url, callbackURLScheme: callbackURLScheme)
        }
    }
}

extension ASWebAuthenticationSession {
    /// Creates, configures and starts a session, then suspends until it completes.
    /// - Returns: The callback URL delivered by the session.
    /// - Throws: The session's error, `OAuth2.Error.invalidRedirectURI`, or `OAuth2.Error.unknown`.
    static func start(
        url: URL,
        redirectURI: URL,
        configure: (ASWebAuthenticationSession) -> Void
    ) async throws -> URL {
        try await withCheckedThrowingContinuation { continuation in
            do {
                let session: ASWebAuthenticationSession
                if #available(iOS 17.4, macOS 14.4, tvOS 17.4, watchOS 10.4, *) {
                    session = ASWebAuthenticationSession(
                        url: url,
                        callback: try callback(for: redirectURI),
                        completionHandler: completionHandler(for: continuation)
                    )
                } else {
                    session = ASWebAuthenticationSession(
                        url: url,
                        callbackURLScheme: try callbackURLScheme(for: redirectURI),
                        completionHandler: completionHandler(for: continuation)
                    )
                }
                configure(session)
                // NOTE(review): `start()` returns false when the session cannot begin; confirm the
                // completion handler still fires in that case, or this continuation never resumes.
                session.start()
            } catch {
                continuation.resume(throwing: error)
            }
        }
    }

    /// Adapts the session's completion callback to resume `continuation` exactly once.
    private static func completionHandler(
        for continuation: CheckedContinuation<URL, Error>
    ) -> CompletionHandler {
        { url, error in
            if let url {
                continuation.resume(returning: url)
            } else {
                continuation.resume(throwing: error ?? OAuth2.Error.unknown)
            }
        }
    }
}

extension ASWebAuthenticationSession {
    /// Maps a redirect URI onto a callback matcher: https with a host, or any custom scheme.
    /// - Throws: `OAuth2.Error.invalidRedirectURI` for http, scheme-less, or host-less https URIs.
    @available(iOS 17.4, macOS 14.4, tvOS 17.4, watchOS 10.4, *)
    fileprivate static func callback(for redirectURI: URL) throws -> Callback {
        switch redirectURI.scheme {
        case "https":
            guard let host = redirectURI.host else { throw OAuth2.Error.invalidRedirectURI }
            return .https(host: host, path: redirectURI.path)
        case "http":
            throw OAuth2.Error.invalidRedirectURI
        case .some(let scheme):
            return .customScheme(scheme)
        case .none:
            throw OAuth2.Error.invalidRedirectURI
        }
    }

    /// Callback URL scheme for pre-17.4 APIs; nil for http or scheme-less redirect URIs.
    /// - Throws: On macOS, `OAuth2.Error.invalidRedirectURI` when an https redirect host is not
    ///   listed in the app's associated-domains entitlement (skipped if the entitlement can't be read).
    fileprivate static func callbackURLScheme(for redirectURI: URL) throws -> String? {
        switch redirectURI.scheme {
        case "http", .none:
            return nil
        case "https":
            #if os(macOS)
            if let host = redirectURI.host?.lowercased(),
               let associatedHosts = try? CodeSigningTask.current.associatedHosts,
               !associatedHosts.contains(host) {
                throw OAuth2.Error.invalidRedirectURI
            }
            #endif
            return "https"
        case .some(let scheme):
            return scheme
        }
    }
}

extension URLComponents {
    /// Decodes `T` from the query, treating each query item as a string-valued key.
    /// - Throws: `DecodingError` when there is no query, a name repeats, or `T` does not match.
    fileprivate func decode<T: Decodable>(_ type: T.Type) throws -> T {
        guard let queryItems else {
            throw DecodingError.valueNotFound(
                T.self,
                .init(codingPath: [], debugDescription: "Missing query items")
            )
        }
        let data = try JSONEncoder().encode(try queryItems.values)
        return try JSONDecoder().decode(T.self, from: data)
    }
}

extension Sequence where Element == URLQueryItem {
    /// The items as a name → value dictionary; throws `DecodingError.dataCorrupted` on duplicate names.
    fileprivate var values: [String: String?] {
        get throws {
            try Dictionary(map { ($0.name, $0.value) }) { _, _ in
                throw DecodingError.dataCorrupted(
                    .init(codingPath: [], debugDescription: "Duplicate query items")
                )
            }
        }
    }
}

#if os(macOS)
/// Wrapper over the current process's `SecTask`, used to read code-signing entitlements.
/// Deliberately not named `Task`, which would shadow Swift Concurrency's `Task` in this file.
private struct CodeSigningTask {
    enum Error: Swift.Error {
        case unknown
    }

    static var current: Self {
        get throws {
            guard let task = SecTaskCreateFromSelf(nil) else { throw Error.unknown }
            return Self(task: task)
        }
    }

    var task: SecTask

    /// Raw `com.apple.developer.associated-domains` entries; empty when the entitlement is absent.
    var associatedDomains: [String] {
        get throws {
            var error: Unmanaged<CFError>?
            let value = SecTaskCopyValueForEntitlement(
                task,
                "com.apple.developer.associated-domains" as CFString,
                &error
            )
            if let error = error?.takeRetainedValue() { throw error }
            // A missing entitlement yields nil with no error, so never force-cast here.
            return value as? [String] ?? []
        }
    }

    /// Lower-cased hosts from `associatedDomains`.
    ///
    /// Entries have the form `<service>:<domain>[?mode=...]`; the service prefix and mode suffix
    /// are stripped. Wildcard entries (`*.example.com`) are not expanded.
    var associatedHosts: Set<String> {
        get throws {
            Set(try associatedDomains.map { entry in
                let domain = entry.split(separator: ":", maxSplits: 1).last ?? Substring(entry)
                return String(domain.prefix { $0 != "?" }).lowercased()
            })
        }
    }
}
#endif