Skip to content

Commit

Permalink
Properly fulfill write promises
Browse files Browse the repository at this point in the history
  • Loading branch information
MahdiBM committed Aug 8, 2024
1 parent d18b137 commit 7aaae6e
Show file tree
Hide file tree
Showing 13 changed files with 300 additions and 231 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,10 @@ struct AuthenticationStateMachine {
}

enum Action {
case sendStartupMessage(AuthContext)
case sendPassword(PasswordAuthencationMode, AuthContext)
case sendSaslInitialResponse(name: String, initialResponse: [UInt8])
case sendSaslResponse([UInt8])
case sendStartupMessage(AuthContext, promise: EventLoopPromise<Void>?)
case sendPassword(PasswordAuthencationMode, AuthContext, promise: EventLoopPromise<Void>?)
case sendSaslInitialResponse(name: String, initialResponse: [UInt8], promise: EventLoopPromise<Void>?)
case sendSaslResponse([UInt8], promise: EventLoopPromise<Void>?)
case wait
case authenticated

Expand All @@ -34,12 +34,12 @@ struct AuthenticationStateMachine {
self.state = .initialized
}

mutating func start() -> Action {
mutating func start(_ promise: EventLoopPromise<Void>?) -> Action {
guard case .initialized = self.state else {
preconditionFailure("Unexpected state")
}
self.state = .startupMessageSent
return .sendStartupMessage(self.authContext)
return .sendStartupMessage(self.authContext, promise: promise)
}

mutating func authenticationMessageReceived(_ message: PostgresBackendMessage.Authentication) -> Action {
Expand All @@ -54,10 +54,10 @@ struct AuthenticationStateMachine {
return self.setAndFireError(PSQLError(code: .authMechanismRequiresPassword))
}
self.state = .passwordAuthenticationSent
return .sendPassword(.md5(salt: salt), self.authContext)
return .sendPassword(.md5(salt: salt), self.authContext, promise: nil)
case .plaintext:
self.state = .passwordAuthenticationSent
return .sendPassword(.cleartext, authContext)
return .sendPassword(.cleartext, authContext, promise: nil)
case .kerberosV5:
return self.setAndFireError(.unsupportedAuthMechanism(.kerberosV5))
case .scmCredential:
Expand Down Expand Up @@ -89,7 +89,7 @@ struct AuthenticationStateMachine {
}

self.state = .saslInitialResponseSent(saslManager)
return .sendSaslInitialResponse(name: SASLMechanism.SCRAM.SHA256.name, initialResponse: output)
return .sendSaslInitialResponse(name: SASLMechanism.SCRAM.SHA256.name, initialResponse: output, promise: nil)
} catch {
return self.setAndFireError(.sasl(underlying: error))
}
Expand Down Expand Up @@ -122,7 +122,7 @@ struct AuthenticationStateMachine {
}

self.state = .saslChallengeResponseSent(saslManager)
return .sendSaslResponse(output)
return .sendSaslResponse(output, promise: nil)
} catch {
return self.setAndFireError(.sasl(underlying: error))
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ struct CloseStateMachine {
}

enum Action {
case sendCloseSync(CloseTarget)
case sendCloseSync(CloseTarget, promise: EventLoopPromise<Void>?)
case succeedClose(CloseCommandContext)
case failClose(CloseCommandContext, with: PSQLError)

Expand All @@ -24,14 +24,14 @@ struct CloseStateMachine {
self.state = .initialized(closeContext)
}

mutating func start() -> Action {
mutating func start(_ promise: EventLoopPromise<Void>?) -> Action {
guard case .initialized(let closeContext) = self.state else {
preconditionFailure("Start should only be called, if the query has been initialized")
}

self.state = .closeSyncSent(closeContext)

return .sendCloseSync(closeContext.target)
return .sendCloseSync(closeContext.target, promise: promise)
}

mutating func closeCompletedReceived() -> Action {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,8 @@ struct ConnectionStateMachine {

case read
case wait
case sendSSLRequest
case establishSSLConnection
case sendSSLRequest(EventLoopPromise<Void>?)
case establishSSLConnection(EventLoopPromise<Void>?)
case provideAuthenticationContext
case forwardNotificationToListeners(PostgresBackendMessage.NotificationResponse)
case fireEventReadyForQuery
Expand All @@ -77,16 +77,16 @@ struct ConnectionStateMachine {
case closeConnectionAndCleanup(CleanUpContext)

// Auth Actions
case sendStartupMessage(AuthContext)
case sendPasswordMessage(PasswordAuthencationMode, AuthContext)
case sendSaslInitialResponse(name: String, initialResponse: [UInt8])
case sendSaslResponse([UInt8])
case sendStartupMessage(AuthContext, promise: EventLoopPromise<Void>?)
case sendPasswordMessage(PasswordAuthencationMode, AuthContext, promise: EventLoopPromise<Void>?)
case sendSaslInitialResponse(name: String, initialResponse: [UInt8], promise: EventLoopPromise<Void>?)
case sendSaslResponse([UInt8], promise: EventLoopPromise<Void>?)

// Connection Actions

// --- general actions
case sendParseDescribeBindExecuteSync(PostgresQuery)
case sendBindExecuteSync(PSQLExecuteStatement)
case sendParseDescribeBindExecuteSync(PostgresQuery, promise: EventLoopPromise<Void>?)
case sendBindExecuteSync(PSQLExecuteStatement, promise: EventLoopPromise<Void>?)
case failQuery(EventLoopPromise<PSQLRowStream>, with: PSQLError, cleanupContext: CleanUpContext?)
case succeedQuery(EventLoopPromise<PSQLRowStream>, with: QueryResult)

Expand All @@ -97,12 +97,12 @@ struct ConnectionStateMachine {
case forwardStreamError(PSQLError, read: Bool, cleanupContext: CleanUpContext?)

// Prepare statement actions
case sendParseDescribeSync(name: String, query: String, bindingDataTypes: [PostgresDataType])
case sendParseDescribeSync(name: String, query: String, bindingDataTypes: [PostgresDataType], promise: EventLoopPromise<Void>?)
case succeedPreparedStatementCreation(EventLoopPromise<RowDescription?>, with: RowDescription?)
case failPreparedStatementCreation(EventLoopPromise<RowDescription?>, with: PSQLError, cleanupContext: CleanUpContext?)

// Close actions
case sendCloseSync(CloseTarget)
case sendCloseSync(CloseTarget, promise: EventLoopPromise<Void>?)
case succeedClose(CloseCommandContext)
case failClose(CloseCommandContext, with: PSQLError, cleanupContext: CleanUpContext?)
}
Expand Down Expand Up @@ -131,7 +131,7 @@ struct ConnectionStateMachine {
case require
}

mutating func connected(tls: TLSConfiguration) -> ConnectionAction {
mutating func connected(tls: TLSConfiguration, promise: EventLoopPromise<Void>?) -> ConnectionAction {
switch self.state {
case .initialized:
switch tls {
Expand All @@ -141,11 +141,11 @@ struct ConnectionStateMachine {

case .prefer:
self.state = .sslRequestSent(.prefer)
return .sendSSLRequest
return .sendSSLRequest(promise)

case .require:
self.state = .sslRequestSent(.require)
return .sendSSLRequest
return .sendSSLRequest(promise)
}

case .sslRequestSent,
Expand All @@ -164,8 +164,11 @@ struct ConnectionStateMachine {
}
}

mutating func provideAuthenticationContext(_ authContext: AuthContext) -> ConnectionAction {
self.startAuthentication(authContext)
mutating func provideAuthenticationContext(
_ authContext: AuthContext,
promise: EventLoopPromise<Void>?
) -> ConnectionAction {
self.startAuthentication(authContext, promise: promise)
}

mutating func gracefulClose(_ promise: EventLoopPromise<Void>?) -> ConnectionAction {
Expand Down Expand Up @@ -233,8 +236,8 @@ struct ConnectionStateMachine {
return self.closeConnectionAndCleanup(.receivedUnencryptedDataAfterSSLRequest)
}
self.state = .sslNegotiated
return .establishSSLConnection
return .establishSSLConnection(nil)

case .initialized,
.sslNegotiated,
.sslHandlerAdded,
Expand Down Expand Up @@ -583,14 +586,16 @@ struct ConnectionStateMachine {
}

switch task {
case .extendedQuery(let queryContext):
case .extendedQuery(let queryContext, let writePromise):
writePromise?.fail(psqlErrror) /// Use `cleanupContext` or not?
switch queryContext.query {
case .executeStatement(_, let promise), .unnamed(_, let promise):
return .failQuery(promise, with: psqlErrror, cleanupContext: nil)
case .prepareStatement(_, _, _, let promise):
return .failPreparedStatementCreation(promise, with: psqlErrror, cleanupContext: nil)
}
case .closeCommand(let closeContext):
case .closeCommand(let closeContext, let writePromise):
writePromise?.fail(psqlErrror) /// Use `cleanupContext` or not?
return .failClose(closeContext, with: psqlErrror, cleanupContext: nil)
}
}
Expand Down Expand Up @@ -800,14 +805,17 @@ struct ConnectionStateMachine {

// MARK: - Private Methods -

private mutating func startAuthentication(_ authContext: AuthContext) -> ConnectionAction {
private mutating func startAuthentication(
_ authContext: AuthContext,
promise: EventLoopPromise<Void>?
) -> ConnectionAction {
guard case .waitingToStartAuthentication = self.state else {
preconditionFailure("Can only start authentication after connect or ssl establish")
}

self.state = .modifying // avoid CoW
var authState = AuthenticationStateMachine(authContext: authContext)
let action = authState.start()
let action = authState.start(promise)
self.state = .authenticating(authState)
return self.modify(with: action)
}
Expand Down Expand Up @@ -934,17 +942,17 @@ struct ConnectionStateMachine {
}

switch task {
case .extendedQuery(let queryContext):
case .extendedQuery(let queryContext, let promise):
self.state = .modifying // avoid CoW
var extendedQuery = ExtendedQueryStateMachine(queryContext: queryContext)
let action = extendedQuery.start()
let action = extendedQuery.start(promise)
self.state = .extendedQuery(extendedQuery, connectionContext)
return self.modify(with: action)

case .closeCommand(let closeContext):
case .closeCommand(let closeContext, let promise):
self.state = .modifying // avoid CoW
var closeStateMachine = CloseStateMachine(closeContext: closeContext)
let action = closeStateMachine.start()
let action = closeStateMachine.start(promise)
self.state = .closeCommand(closeStateMachine, connectionContext)
return self.modify(with: action)
}
Expand Down Expand Up @@ -1031,10 +1039,10 @@ extension ConnectionStateMachine {
extension ConnectionStateMachine {
mutating func modify(with action: ExtendedQueryStateMachine.Action) -> ConnectionStateMachine.ConnectionAction {
switch action {
case .sendParseDescribeBindExecuteSync(let query):
return .sendParseDescribeBindExecuteSync(query)
case .sendBindExecuteSync(let executeStatement):
return .sendBindExecuteSync(executeStatement)
case .sendParseDescribeBindExecuteSync(let query, let promise):
return .sendParseDescribeBindExecuteSync(query, promise: promise)
case .sendBindExecuteSync(let executeStatement, let promise):
return .sendBindExecuteSync(executeStatement, promise: promise)
case .failQuery(let requestContext, with: let error):
let cleanupContext = self.setErrorAndCreateCleanupContextIfNeeded(error)
return .failQuery(requestContext, with: error, cleanupContext: cleanupContext)
Expand All @@ -1057,8 +1065,8 @@ extension ConnectionStateMachine {
return .read
case .wait:
return .wait
case .sendParseDescribeSync(name: let name, query: let query, bindingDataTypes: let bindingDataTypes):
return .sendParseDescribeSync(name: name, query: query, bindingDataTypes: bindingDataTypes)
case .sendParseDescribeSync(name: let name, query: let query, bindingDataTypes: let bindingDataTypes, let promise):
return .sendParseDescribeSync(name: name, query: query, bindingDataTypes: bindingDataTypes, promise: promise)
case .succeedPreparedStatementCreation(let promise, with: let rowDescription):
return .succeedPreparedStatementCreation(promise, with: rowDescription)
case .failPreparedStatementCreation(let promise, with: let error):
Expand All @@ -1071,14 +1079,14 @@ extension ConnectionStateMachine {
extension ConnectionStateMachine {
mutating func modify(with action: AuthenticationStateMachine.Action) -> ConnectionStateMachine.ConnectionAction {
switch action {
case .sendStartupMessage(let authContext):
return .sendStartupMessage(authContext)
case .sendPassword(let mode, let authContext):
return .sendPasswordMessage(mode, authContext)
case .sendSaslInitialResponse(let name, let initialResponse):
return .sendSaslInitialResponse(name: name, initialResponse: initialResponse)
case .sendSaslResponse(let bytes):
return .sendSaslResponse(bytes)
case .sendStartupMessage(let authContext, let promise):
return .sendStartupMessage(authContext, promise: promise)
case .sendPassword(let mode, let authContext, let promise):
return .sendPasswordMessage(mode, authContext, promise: promise)
case .sendSaslInitialResponse(let name, let initialResponse, let promise):
return .sendSaslInitialResponse(name: name, initialResponse: initialResponse, promise: promise)
case .sendSaslResponse(let bytes, let promise):
return .sendSaslResponse(bytes, promise: promise)
case .authenticated:
self.state = .authenticated(nil, [:])
return .wait
Expand All @@ -1094,8 +1102,8 @@ extension ConnectionStateMachine {
extension ConnectionStateMachine {
mutating func modify(with action: CloseStateMachine.Action) -> ConnectionStateMachine.ConnectionAction {
switch action {
case .sendCloseSync(let sendClose):
return .sendCloseSync(sendClose)
case .sendCloseSync(let sendClose, let promise):
return .sendCloseSync(sendClose, promise: promise)
case .succeedClose(let closeContext):
return .succeedClose(closeContext)
case .failClose(let closeContext, with: let error):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,10 @@ struct ExtendedQueryStateMachine {
}

enum Action {
case sendParseDescribeBindExecuteSync(PostgresQuery)
case sendParseDescribeSync(name: String, query: String, bindingDataTypes: [PostgresDataType])
case sendBindExecuteSync(PSQLExecuteStatement)
case sendParseDescribeBindExecuteSync(PostgresQuery, promise: EventLoopPromise<Void>?)
case sendParseDescribeSync(name: String, query: String, bindingDataTypes: [PostgresDataType], promise: EventLoopPromise<Void>?)
case sendBindExecuteSync(PSQLExecuteStatement, promise: EventLoopPromise<Void>?)

// --- general actions
case failQuery(EventLoopPromise<PSQLRowStream>, with: PSQLError)
case succeedQuery(EventLoopPromise<PSQLRowStream>, with: QueryResult)
Expand Down Expand Up @@ -56,7 +56,7 @@ struct ExtendedQueryStateMachine {
self.state = .initialized(queryContext)
}

mutating func start() -> Action {
mutating func start(_ promise: EventLoopPromise<Void>?) -> Action {
guard case .initialized(let queryContext) = self.state else {
preconditionFailure("Start should only be called, if the query has been initialized")
}
Expand All @@ -65,7 +65,7 @@ struct ExtendedQueryStateMachine {
case .unnamed(let query, _):
return self.avoidingStateMachineCoW { state -> Action in
state = .messagesSent(queryContext)
return .sendParseDescribeBindExecuteSync(query)
return .sendParseDescribeBindExecuteSync(query, promise: promise)
}

case .executeStatement(let prepared, _):
Expand All @@ -76,13 +76,14 @@ struct ExtendedQueryStateMachine {
case .none:
state = .noDataMessageReceived(queryContext)
}
return .sendBindExecuteSync(prepared)
return .sendBindExecuteSync(prepared, promise: promise)
}

/// Not my code, but this is ignoring the last argument which is a promise? is that fine?
case .prepareStatement(let name, let query, let bindingDataTypes, _):
return self.avoidingStateMachineCoW { state -> Action in
state = .messagesSent(queryContext)
return .sendParseDescribeSync(name: name, query: query, bindingDataTypes: bindingDataTypes)
return .sendParseDescribeSync(name: name, query: query, bindingDataTypes: bindingDataTypes, promise: promise)
}
}
}
Expand Down
15 changes: 10 additions & 5 deletions Sources/PostgresNIO/New/NotificationListener.swift
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,10 @@ final class NotificationListener: @unchecked Sendable {
self.state = .closure(context, closure)
}

func startListeningSucceeded(handler: PostgresChannelHandler) {
func startListeningSucceeded(
handler: PostgresChannelHandler,
writePromise: EventLoopPromise<Void>?
) {
self.eventLoop.preconditionInEventLoop()
let handlerLoopBound = NIOLoopBound(handler, eventLoop: self.eventLoop)

Expand All @@ -56,26 +59,28 @@ final class NotificationListener: @unchecked Sendable {
switch reason {
case .cancelled:
eventLoop.execute {
handlerLoopBound.value.cancelNotificationListener(channel: channel, id: listenerID)
handlerLoopBound.value.cancelNotificationListener(channel: channel, id: listenerID, writePromise: nil)
}

case .finished:
break
writePromise?.succeed()

@unknown default:
break
writePromise?.succeed()
}
}
self.state = .streamListening(continuation)

let notificationSequence = PostgresNotificationSequence(base: stream)
checkedContinuation.resume(returning: notificationSequence)
writePromise?.succeed(())

case .streamListening, .done:
fatalError("Invalid state: \(self.state)")

case .closure:
break // ignore
writePromise?.succeed(())
// ignore
}
}

Expand Down
Loading

0 comments on commit 7aaae6e

Please sign in to comment.