Skip to content

Commit fcf29f7

Browse files
feat!: Passes init payload result
This automatically propagates the init payload result from the callback into the `onExecute` and `onSubscribe` closures. Since the init callback is usually used to determine authentication and authorization, this should be usable from our execution and subscription calls, and lifecycle is most easily managed within this package
1 parent 43a64cb commit fcf29f7

2 files changed

Lines changed: 87 additions & 38 deletions

File tree

Sources/GraphQLTransportWS/Server.swift

Lines changed: 14 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -6,23 +6,25 @@ import GraphQL
66
/// By default, there are no authorization checks
77
public class Server<
88
InitPayload: Equatable & Codable & Sendable,
9+
InitPayloadResult: Sendable,
910
SubscriptionSequenceType: AsyncSequence & Sendable
1011
>: @unchecked Sendable where
1112
SubscriptionSequenceType.Element == GraphQLResult
1213
{
1314
// We keep this weak because we strongly inject this object into the messenger callback
1415
weak var messenger: Messenger?
15-
16-
let onExecute: (GraphQLRequest) async throws -> GraphQLResult
17-
let onSubscribe: (GraphQLRequest) async throws -> SubscriptionSequenceType
18-
var auth: (InitPayload) async throws -> Void
16+
17+
let onInit: (InitPayload) async throws -> InitPayloadResult
18+
let onExecute: (GraphQLRequest, InitPayloadResult) async throws -> GraphQLResult
19+
let onSubscribe: (GraphQLRequest, InitPayloadResult) async throws -> SubscriptionSequenceType
1920

2021
var onExit: () async throws -> Void = {}
2122
var onMessage: (String) async throws -> Void = { _ in }
2223
var onOperationComplete: (String) async throws -> Void = { _ in }
2324
var onOperationError: (String, [Error]) async throws -> Void = { _, _ in }
2425

2526
var initialized = false
27+
var initResult: InitPayloadResult?
2628

2729
let decoder = JSONDecoder()
2830
let encoder = GraphQLJSONEncoder()
@@ -37,13 +39,14 @@ public class Server<
3739
/// - onSubscribe: Callback run during `start` resolution for streaming queries. Typically this is `API.subscribe`.
3840
public init(
3941
messenger: Messenger,
40-
onExecute: @escaping (GraphQLRequest) async throws -> GraphQLResult,
41-
onSubscribe: @escaping (GraphQLRequest) async throws -> SubscriptionSequenceType
42+
onInit: @escaping (InitPayload) async throws -> InitPayloadResult,
43+
onExecute: @escaping (GraphQLRequest, InitPayloadResult) async throws -> GraphQLResult,
44+
onSubscribe: @escaping (GraphQLRequest, InitPayloadResult) async throws -> SubscriptionSequenceType
4245
) {
4346
self.messenger = messenger
47+
self.onInit = onInit
4448
self.onExecute = onExecute
4549
self.onSubscribe = onSubscribe
46-
auth = { _ in }
4750

4851
messenger.onReceive { message in
4952
guard let messenger = self.messenger else { return }
@@ -99,13 +102,6 @@ public class Server<
99102
subscriptionTasks.values.forEach { $0.cancel() }
100103
}
101104

102-
/// Define a custom callback run during `connection_init` resolution that allows authorization using the `payload`.
103-
/// Throw from this closure to indicate that authorization has failed.
104-
/// - Parameter callback: The callback to assign
105-
public func auth(_ callback: @escaping (InitPayload) async throws -> Void) {
106-
auth = callback
107-
}
108-
109105
/// Define the callback run when the communication is shut down, either by the client or server
110106
/// - Parameter callback: The callback to assign
111107
public func onExit(_ callback: @escaping () -> Void) {
@@ -137,7 +133,7 @@ public class Server<
137133
}
138134

139135
do {
140-
try await auth(connectionInitRequest.payload)
136+
initResult = try await onInit(connectionInitRequest.payload)
141137
} catch {
142138
try await self.error(.unauthorized())
143139
return
@@ -148,7 +144,7 @@ public class Server<
148144
}
149145

150146
private func onSubscribe(_ subscribeRequest: SubscribeRequest) async throws {
151-
guard initialized else {
147+
guard initialized, let initResult else {
152148
try await error(.notInitialized())
153149
return
154150
}
@@ -171,7 +167,7 @@ public class Server<
171167
if isStreaming {
172168
subscriptionTasks[id] = Task {
173169
do {
174-
let stream = try await onSubscribe(graphQLRequest)
170+
let stream = try await onSubscribe(graphQLRequest, initResult)
175171
for try await event in stream {
176172
try Task.checkCancellation()
177173
try await self.sendNext(event, id: id)
@@ -186,7 +182,7 @@ public class Server<
186182
}
187183
} else {
188184
do {
189-
let result = try await onExecute(graphQLRequest)
185+
let result = try await onExecute(graphQLRequest, initResult)
190186
try await sendNext(result, id: id)
191187
try await sendComplete(id: id)
192188
} catch {

Tests/GraphQLTransportWSTests/GraphQLTransportWSTests.swift

Lines changed: 73 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -8,42 +8,40 @@ import GraphQLTransportWS
88
class GraphqlTransportWSTests: XCTestCase {
99
var clientMessenger: TestMessenger!
1010
var serverMessenger: TestMessenger!
11-
var server: Server<TokenInitPayload, AsyncThrowingStream<GraphQLResult, Error>>!
12-
var context: TestContext!
1311
var subscribeReady: Bool! = false
12+
13+
let context = TestContext()
14+
let api = TestAPI()
1415

1516
override func setUp() {
1617
// Point the client and server at each other
1718
clientMessenger = TestMessenger()
1819
serverMessenger = TestMessenger()
1920
clientMessenger.other = serverMessenger
2021
serverMessenger.other = clientMessenger
22+
}
2123

22-
let api = TestAPI()
23-
let context = TestContext()
24-
25-
server = .init(
24+
/// Tests that trying to run methods before `connection_init` is not allowed
25+
func testInitialize() async throws {
26+
_ = Server<TokenInitPayload, Void, AsyncThrowingStream<GraphQLResult, Error>>(
2627
messenger: serverMessenger,
27-
onExecute: { graphQLRequest in
28-
try await api.execute(
28+
onInit: { _ in },
29+
onExecute: { graphQLRequest, _ in
30+
try await self.api.execute(
2931
request: graphQLRequest.query,
30-
context: context
32+
context: self.context
3133
)
3234
},
33-
onSubscribe: { graphQLRequest in
34-
let subscription = try await api.subscribe(
35+
onSubscribe: { graphQLRequest, _ in
36+
let subscription = try await self.api.subscribe(
3537
request: graphQLRequest.query,
36-
context: context
38+
context: self.context
3739
).get()
3840
self.subscribeReady = true
3941
return subscription
4042
}
4143
)
42-
self.context = context
43-
}
44-
45-
/// Tests that trying to run methods before `connection_init` is not allowed
46-
func testInitialize() async throws {
44+
4745
let client = Client<TokenInitPayload>(messenger: clientMessenger)
4846
let messageStream = AsyncThrowingStream<String, any Error> { continuation in
4947
client.onMessage { message, _ in
@@ -78,9 +76,26 @@ class GraphqlTransportWSTests: XCTestCase {
7876

7977
/// Tests that throwing in the authorization callback forces an unauthorized error
8078
func testAuthWithThrow() async throws {
81-
server.auth { _ in
82-
throw TestError.couldBeAnything
83-
}
79+
_ = Server<TokenInitPayload, Void, AsyncThrowingStream<GraphQLResult, Error>>(
80+
messenger: serverMessenger,
81+
onInit: { _ in
82+
throw TestError.couldBeAnything
83+
},
84+
onExecute: { graphQLRequest, _ in
85+
try await self.api.execute(
86+
request: graphQLRequest.query,
87+
context: self.context
88+
)
89+
},
90+
onSubscribe: { graphQLRequest, _ in
91+
let subscription = try await self.api.subscribe(
92+
request: graphQLRequest.query,
93+
context: self.context
94+
).get()
95+
self.subscribeReady = true
96+
return subscription
97+
}
98+
)
8499

85100
let client = Client<TokenInitPayload>(messenger: clientMessenger)
86101
let messageStream = AsyncThrowingStream<String, any Error> { continuation in
@@ -111,6 +126,25 @@ class GraphqlTransportWSTests: XCTestCase {
111126

112127
/// Tests a single-op conversation
113128
func testSingleOp() async throws {
129+
_ = Server<TokenInitPayload, Void, AsyncThrowingStream<GraphQLResult, Error>>(
130+
messenger: serverMessenger,
131+
onInit: { _ in },
132+
onExecute: { graphQLRequest, _ in
133+
try await self.api.execute(
134+
request: graphQLRequest.query,
135+
context: self.context
136+
)
137+
},
138+
onSubscribe: { graphQLRequest, _ in
139+
let subscription = try await self.api.subscribe(
140+
request: graphQLRequest.query,
141+
context: self.context
142+
).get()
143+
self.subscribeReady = true
144+
return subscription
145+
}
146+
)
147+
114148
let id = UUID().description
115149

116150
let client = Client<TokenInitPayload>(messenger: clientMessenger)
@@ -156,6 +190,25 @@ class GraphqlTransportWSTests: XCTestCase {
156190

157191
/// Tests a streaming conversation
158192
func testStreaming() async throws {
193+
_ = Server<TokenInitPayload, Void, AsyncThrowingStream<GraphQLResult, Error>>(
194+
messenger: serverMessenger,
195+
onInit: { _ in },
196+
onExecute: { graphQLRequest, _ in
197+
try await self.api.execute(
198+
request: graphQLRequest.query,
199+
context: self.context
200+
)
201+
},
202+
onSubscribe: { graphQLRequest, _ in
203+
let subscription = try await self.api.subscribe(
204+
request: graphQLRequest.query,
205+
context: self.context
206+
).get()
207+
self.subscribeReady = true
208+
return subscription
209+
}
210+
)
211+
159212
let id = UUID().description
160213

161214
var dataIndex = 1

0 commit comments

Comments
 (0)