Skip to content

Commit a726bd3

Browse files
feat!: Messenger is send-only
To support receiving messages, we added `listen` functions to server and client. This resolves the confusing ownership rules by avoiding `onReceive` callbacks in Messenger.
1 parent fcf29f7 commit a726bd3

6 files changed

Lines changed: 79 additions & 77 deletions

File tree

README.md

Lines changed: 7 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -26,32 +26,21 @@ import GraphQLTransportWS
2626

2727
/// Messenger wrapper for WebSockets
2828
class WebSocketMessenger: Messenger {
29-
private weak var websocket: WebSocket?
30-
private var onReceive: (String) async throws -> Void = { _ in }
29+
let websocket: WebSocket?
3130

3231
init(websocket: WebSocket) {
3332
self.websocket = websocket
34-
websocket.onText { _, message in
35-
try await self.onReceive(message)
36-
}
3733
}
3834

3935
func send<S>(_ message: S) where S: Collection, S.Element == Character async throws {
40-
guard let websocket = websocket else { return }
4136
try await websocket.send(message)
4237
}
4338

44-
func onReceive(callback: @escaping (String) async throws -> Void) {
45-
self.onReceive = callback
46-
}
47-
4839
func error(_ message: String, code: Int) async throws {
49-
guard let websocket = websocket else { return }
5040
try await websocket.send("\(code): \(message)")
5141
}
5242

5343
func close() async throws {
54-
guard let websocket = websocket else { return }
5544
try await websocket.close()
5645
}
5746
}
@@ -85,6 +74,12 @@ routes.webSocket(
8574
)
8675
}
8776
)
77+
let incoming = AsyncStream<String> { continuation in
78+
websocket.onText { _, message in
79+
continuation.yield(message)
80+
}
81+
}
82+
try await server.listen(to: incoming)
8883
}
8984
)
9085
```
@@ -125,12 +120,3 @@ This example would require `connection_init` message from the client to look lik
125120
```
126121

127122
If the `payload` field is not required on your server, you may make Server's generic declaration optional like `Server<Payload?>`
128-
129-
## Memory Management
130-
131-
Memory ownership among the Server, Client, and Messenger may seem a little backwards. This is because the Swift/Vapor WebSocket
132-
implementation persists WebSocket objects long after their callback and they are expected to retain strong memory references to the
133-
objects required for responses. In order to align cleanly and avoid memory cycles, Server and Client are injected strongly into Messenger
134-
callbacks, and only hold weak references to their Messenger. This means that Messenger objects (or their enclosing WebSocket) must
135-
be persisted to have the connected Server or Client objects function. That is, if a Server's Messenger falls out of scope and deinitializes,
136-
the Server will no longer respond to messages.

Sources/GraphQLTransportWS/Client.swift

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,8 @@ import Foundation
22
import GraphQL
33

44
/// Client is an open-ended implementation of the client side of the protocol. It parses and adds callbacks for each type of server respose.
5-
public class Client<InitPayload: Equatable & Codable> {
6-
// We keep this weak because we strongly inject this object into the messenger callback
7-
weak var messenger: Messenger?
5+
public class Client<InitPayload: Equatable & Codable>: @unchecked Sendable {
6+
let messenger: Messenger
87

98
var onConnectionAck: (ConnectionAckResponse, Client) async throws -> Void = { _, _ in }
109
var onNext: (NextResponse, Client) async throws -> Void = { _, _ in }
@@ -23,7 +22,12 @@ public class Client<InitPayload: Equatable & Codable> {
2322
messenger: Messenger
2423
) {
2524
self.messenger = messenger
26-
messenger.onReceive { message in
25+
}
26+
27+
/// Listen and react to the provided async sequence of server messages. This function will block until the stream is completed.
28+
/// - Parameter incoming: The server message sequence that the client should react to.
29+
public func listen<A: AsyncSequence & Sendable>(to incoming: A) async throws -> Void where A.Element == String {
30+
for try await message in incoming {
2731
try await self.onMessage(message, self)
2832

2933
// Detect and ignore error responses.
@@ -108,7 +112,6 @@ public class Client<InitPayload: Equatable & Codable> {
108112

109113
/// Send a `connection_init` request through the messenger
110114
public func sendConnectionInit(payload: InitPayload) async throws {
111-
guard let messenger = messenger else { return }
112115
try await messenger.send(
113116
ConnectionInitRequest(
114117
payload: payload
@@ -118,7 +121,6 @@ public class Client<InitPayload: Equatable & Codable> {
118121

119122
/// Send a `subscribe` request through the messenger
120123
public func sendStart(payload: GraphQLRequest, id: String) async throws {
121-
guard let messenger = messenger else { return }
122124
try await messenger.send(
123125
SubscribeRequest(
124126
payload: payload,
@@ -129,7 +131,6 @@ public class Client<InitPayload: Equatable & Codable> {
129131

130132
/// Send a `complete` request through the messenger
131133
public func sendStop(id: String) async throws {
132-
guard let messenger = messenger else { return }
133134
try await messenger.send(
134135
CompleteRequest(
135136
id: id
@@ -139,7 +140,6 @@ public class Client<InitPayload: Equatable & Codable> {
139140

140141
/// Send an error through the messenger and close the connection
141142
private func error(_ error: GraphQLTransportWSError) async throws {
142-
guard let messenger = messenger else { return }
143143
try await messenger.error(error.message, code: error.code.rawValue)
144144
}
145145
}

Sources/GraphQLTransportWS/Messenger.swift

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,10 @@
11
import Foundation
22

3-
/// Protocol for an object that can send and recieve messages. This allows mocking in tests
4-
public protocol Messenger: AnyObject {
5-
// AnyObject compliance requires that the implementing object is a class and we can reference it weakly
6-
3+
/// Protocol for an object that can send messages. This allows mocking in tests
4+
public protocol Messenger {
75
/// Send a message through this messenger
86
/// - Parameter message: The message to send
9-
func send<S>(_ message: S) async throws -> Void where S: Collection, S.Element == Character
10-
11-
/// Set the callback that should be run when a message is recieved
12-
func onReceive(callback: @escaping (String) async throws -> Void)
7+
func send<S: Sendable>(_ message: S) async throws -> Void where S: Collection, S.Element == Character
138

149
/// Close the messenger
1510
func close() async throws

Sources/GraphQLTransportWS/Server.swift

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ public class Server<
1212
SubscriptionSequenceType.Element == GraphQLResult
1313
{
1414
// We keep this weak because we strongly inject this object into the messenger callback
15-
weak var messenger: Messenger?
15+
let messenger: Messenger
1616

1717
let onInit: (InitPayload) async throws -> InitPayloadResult
1818
let onExecute: (GraphQLRequest, InitPayloadResult) async throws -> GraphQLResult
@@ -47,10 +47,12 @@ public class Server<
4747
self.onInit = onInit
4848
self.onExecute = onExecute
4949
self.onSubscribe = onSubscribe
50+
}
5051

51-
messenger.onReceive { message in
52-
guard let messenger = self.messenger else { return }
53-
52+
/// Listen and react to the provided async sequence of client messages. This function will block until the stream is completed.
53+
/// - Parameter incoming: The client message sequence that the server should react to.
54+
public func listen<A: AsyncSequence & Sendable>(to incoming: A) async throws -> Void where A.Element == String {
55+
for try await message in incoming {
5456
try await self.onMessage(message)
5557

5658
// Detect and ignore error responses.
@@ -188,7 +190,7 @@ public class Server<
188190
} catch {
189191
try await sendError(error, id: id)
190192
}
191-
try await messenger?.close()
193+
try await messenger.close()
192194
}
193195
}
194196

@@ -208,15 +210,13 @@ public class Server<
208210

209211
/// Send a `connection_ack` response through the messenger
210212
private func sendConnectionAck(_ payload: [String: Map]? = nil) async throws {
211-
guard let messenger = messenger else { return }
212213
try await messenger.send(
213214
ConnectionAckResponse(payload: payload).toJSON(encoder)
214215
)
215216
}
216217

217218
/// Send a `next` response through the messenger
218219
private func sendNext(_ payload: GraphQLResult? = nil, id: String) async throws {
219-
guard let messenger = messenger else { return }
220220
try await messenger.send(
221221
NextResponse(
222222
payload: payload,
@@ -227,7 +227,6 @@ public class Server<
227227

228228
/// Send a `complete` response through the messenger
229229
private func sendComplete(id: String) async throws {
230-
guard let messenger = messenger else { return }
231230
try await messenger.send(
232231
CompleteResponse(
233232
id: id
@@ -238,7 +237,6 @@ public class Server<
238237

239238
/// Send an `error` response through the messenger
240239
private func sendError(_ errors: [Error], id: String) async throws {
241-
guard let messenger = messenger else { return }
242240
try await messenger.send(
243241
ErrorResponse(
244242
errors,
@@ -260,7 +258,6 @@ public class Server<
260258

261259
/// Send an error through the messenger and close the connection
262260
private func error(_ error: GraphQLTransportWSError) async throws {
263-
guard let messenger = messenger else { return }
264261
try await messenger.error(error.message, code: error.code.rawValue)
265262
}
266263
}

Tests/GraphQLTransportWSTests/GraphQLTransportWSTests.swift

Lines changed: 40 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -14,16 +14,13 @@ class GraphqlTransportWSTests: XCTestCase {
1414
let api = TestAPI()
1515

1616
override func setUp() {
17-
// Point the client and server at each other
1817
clientMessenger = TestMessenger()
1918
serverMessenger = TestMessenger()
20-
clientMessenger.other = serverMessenger
21-
serverMessenger.other = clientMessenger
2219
}
2320

2421
/// Tests that trying to run methods before `connection_init` is not allowed
2522
func testInitialize() async throws {
26-
_ = Server<TokenInitPayload, Void, AsyncThrowingStream<GraphQLResult, Error>>(
23+
let server = Server<TokenInitPayload, Void, AsyncThrowingStream<GraphQLResult, Error>>(
2724
messenger: serverMessenger,
2825
onInit: { _ in },
2926
onExecute: { graphQLRequest, _ in
@@ -41,8 +38,16 @@ class GraphqlTransportWSTests: XCTestCase {
4138
return subscription
4239
}
4340
)
44-
4541
let client = Client<TokenInitPayload>(messenger: clientMessenger)
42+
let serverStream = serverMessenger.stream
43+
let clientStream = clientMessenger.stream
44+
Task {
45+
try await server.listen(to: clientStream)
46+
}
47+
Task {
48+
try await client.listen(to: serverStream)
49+
}
50+
4651
let messageStream = AsyncThrowingStream<String, any Error> { continuation in
4752
client.onMessage { message, _ in
4853
continuation.yield(message)
@@ -76,7 +81,7 @@ class GraphqlTransportWSTests: XCTestCase {
7681

7782
/// Tests that throwing in the authorization callback forces an unauthorized error
7883
func testAuthWithThrow() async throws {
79-
_ = Server<TokenInitPayload, Void, AsyncThrowingStream<GraphQLResult, Error>>(
84+
let server = Server<TokenInitPayload, Void, AsyncThrowingStream<GraphQLResult, Error>>(
8085
messenger: serverMessenger,
8186
onInit: { _ in
8287
throw TestError.couldBeAnything
@@ -96,8 +101,16 @@ class GraphqlTransportWSTests: XCTestCase {
96101
return subscription
97102
}
98103
)
99-
100104
let client = Client<TokenInitPayload>(messenger: clientMessenger)
105+
let clientStream = clientMessenger.stream
106+
let serverStream = serverMessenger.stream
107+
Task {
108+
try await server.listen(to: clientStream)
109+
}
110+
Task {
111+
try await client.listen(to: serverStream)
112+
}
113+
101114
let messageStream = AsyncThrowingStream<String, any Error> { continuation in
102115
client.onMessage { message, _ in
103116
continuation.yield(message)
@@ -126,7 +139,7 @@ class GraphqlTransportWSTests: XCTestCase {
126139

127140
/// Tests a single-op conversation
128141
func testSingleOp() async throws {
129-
_ = Server<TokenInitPayload, Void, AsyncThrowingStream<GraphQLResult, Error>>(
142+
let server = Server<TokenInitPayload, Void, AsyncThrowingStream<GraphQLResult, Error>>(
130143
messenger: serverMessenger,
131144
onInit: { _ in },
132145
onExecute: { graphQLRequest, _ in
@@ -144,10 +157,18 @@ class GraphqlTransportWSTests: XCTestCase {
144157
return subscription
145158
}
146159
)
160+
let client = Client<TokenInitPayload>(messenger: clientMessenger)
161+
let clientStream = clientMessenger.stream
162+
let serverStream = serverMessenger.stream
163+
Task {
164+
try await server.listen(to: clientStream)
165+
}
166+
Task {
167+
try await client.listen(to: serverStream)
168+
}
147169

148170
let id = UUID().description
149171

150-
let client = Client<TokenInitPayload>(messenger: clientMessenger)
151172
let messageStream = AsyncThrowingStream<String, any Error> { continuation in
152173
client.onConnectionAck { _, client in
153174
try await client.sendStart(
@@ -190,7 +211,7 @@ class GraphqlTransportWSTests: XCTestCase {
190211

191212
/// Tests a streaming conversation
192213
func testStreaming() async throws {
193-
_ = Server<TokenInitPayload, Void, AsyncThrowingStream<GraphQLResult, Error>>(
214+
let server = Server<TokenInitPayload, Void, AsyncThrowingStream<GraphQLResult, Error>>(
194215
messenger: serverMessenger,
195216
onInit: { _ in },
196217
onExecute: { graphQLRequest, _ in
@@ -208,13 +229,21 @@ class GraphqlTransportWSTests: XCTestCase {
208229
return subscription
209230
}
210231
)
232+
let client = Client<TokenInitPayload>(messenger: clientMessenger)
233+
let clientStream = clientMessenger.stream
234+
let serverStream = serverMessenger.stream
235+
Task {
236+
try await server.listen(to: clientStream)
237+
}
238+
Task {
239+
try await client.listen(to: serverStream)
240+
}
211241

212242
let id = UUID().description
213243

214244
var dataIndex = 1
215245
let dataIndexMax = 3
216246

217-
let client = Client<TokenInitPayload>(messenger: clientMessenger)
218247
let messageStream = AsyncThrowingStream<String, any Error> { continuation in
219248
client.onConnectionAck { _, client in
220249
try await client.sendStart(

Tests/GraphQLTransportWSTests/Utils/TestMessenger.swift

Lines changed: 14 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -4,32 +4,27 @@ import Foundation
44
@testable import GraphQLTransportWS
55

66
/// Messenger for simple testing that doesn't require starting up a websocket server.
7-
///
8-
/// Note that this only retains a weak reference to 'other', so the client should retain references
9-
/// or risk them being deinitialized early
10-
class TestMessenger: Messenger, @unchecked Sendable {
11-
weak var other: TestMessenger?
12-
var onReceive: (String) async throws -> Void = { _ in }
13-
let queue: DispatchQueue = .init(label: "Test messenger")
14-
15-
init() {}
16-
17-
func send<S: Sendable>(_ message: S) async throws where S: Collection, S.Element == Character {
18-
guard let other = other else {
19-
return
20-
}
21-
try await other.onReceive(String(message))
7+
actor TestMessenger: Messenger {
8+
/// An async stream of the messages sent through this messenger.
9+
let stream: AsyncStream<String>
10+
private var continuation: AsyncStream<String>.Continuation
11+
12+
init() {
13+
let (stream, continuation) = AsyncStream<String>.makeStream()
14+
self.stream = stream
15+
self.continuation = continuation
2216
}
2317

24-
func onReceive(callback: @escaping (String) async throws -> Void) {
25-
onReceive = callback
18+
func send<S: Sendable>(_ message: S) async throws where S: Collection, S.Element == Character {
19+
continuation.yield(String(message))
2620
}
2721

2822
func error(_ message: String, code: Int) async throws {
29-
try await send("\(code): \(message)")
23+
continuation.yield("\(code): \(message)")
24+
continuation.finish()
3025
}
3126

3227
func close() {
33-
// This is a testing no-op
28+
continuation.finish()
3429
}
3530
}

0 commit comments

Comments
 (0)