1919import com .google .api .core .InternalApi ;
2020import com .google .api .gax .grpc .InstantiatingGrpcChannelProvider ;
2121import com .google .protobuf .ByteString ;
22+ import com .google .spanner .v1 .BeginTransactionRequest ;
2223import com .google .spanner .v1 .CommitRequest ;
2324import com .google .spanner .v1 .ExecuteSqlRequest ;
2425import com .google .spanner .v1 .PartialResultSet ;
@@ -52,6 +53,8 @@ final class KeyAwareChannel extends ManagedChannel {
5253 private static final String STREAMING_SQL_METHOD =
5354 "google.spanner.v1.Spanner/ExecuteStreamingSql" ;
5455 private static final String UNARY_SQL_METHOD = "google.spanner.v1.Spanner/ExecuteSql" ;
56+ private static final String BEGIN_TRANSACTION_METHOD =
57+ "google.spanner.v1.Spanner/BeginTransaction" ;
5558 private static final String COMMIT_METHOD = "google.spanner.v1.Spanner/Commit" ;
5659 private static final String ROLLBACK_METHOD = "google.spanner.v1.Spanner/Rollback" ;
5760
@@ -162,6 +165,7 @@ private static boolean isKeyAware(MethodDescriptor<?, ?> methodDescriptor) {
162165 return STREAMING_READ_METHOD .equals (method )
163166 || STREAMING_SQL_METHOD .equals (method )
164167 || UNARY_SQL_METHOD .equals (method )
168+ || BEGIN_TRANSACTION_METHOD .equals (method )
165169 || COMMIT_METHOD .equals (method )
166170 || ROLLBACK_METHOD .equals (method );
167171 }
@@ -185,12 +189,17 @@ private void clearAffinity(ByteString transactionId) {
185189 transactionAffinities .remove (transactionId );
186190 }
187191
188- private void recordAffinity (ByteString transactionId , @ Nullable ChannelEndpoint endpoint ) {
192+ void clearTransactionAffinity (ByteString transactionId ) {
193+ clearAffinity (transactionId );
194+ }
195+
196+ private void recordAffinity (
197+ ByteString transactionId , @ Nullable ChannelEndpoint endpoint , boolean allowDefault ) {
189198 if (transactionId == null || transactionId .isEmpty () || endpoint == null ) {
190199 return ;
191200 }
192201 String address = endpoint .getAddress ();
193- if (defaultEndpointAddress .equals (address )) {
202+ if (! allowDefault && defaultEndpointAddress .equals (address )) {
194203 return ;
195204 }
196205 transactionAffinities .put (transactionId , address );
@@ -238,6 +247,7 @@ static final class KeyAwareClientCall<RequestT, ResponseT>
238247 private ChannelFinder channelFinder ;
239248 @ Nullable private ChannelEndpoint selectedEndpoint ;
240249 @ Nullable private ByteString transactionIdToClear ;
250+ private boolean allowDefaultAffinity ;
241251
242252 KeyAwareClientCall (
243253 KeyAwareChannel parentChannel ,
@@ -295,6 +305,8 @@ public void sendMessage(RequestT message) {
295305 }
296306 }
297307 message = (RequestT ) reqBuilder .build ();
308+ } else if (message instanceof BeginTransactionRequest ) {
309+ allowDefaultAffinity = true ;
298310 } else if (message instanceof CommitRequest ) {
299311 CommitRequest request = (CommitRequest ) message ;
300312 if (!request .getTransactionId ().isEmpty ()) {
@@ -309,7 +321,8 @@ public void sendMessage(RequestT message) {
309321 }
310322 } else {
311323 throw new IllegalStateException (
312- "Only read, query, commit, and rollback requests are supported for key-aware calls." );
324+ "Only read, query, begin transaction, commit, and rollback requests are supported for"
325+ + " key-aware calls." );
313326 }
314327
315328 if (endpoint == null ) {
@@ -343,7 +356,7 @@ public void cancel(@Nullable String message, @Nullable Throwable cause) {
343356 }
344357
345358 void maybeRecordAffinity (ByteString transactionId ) {
346- parentChannel .recordAffinity (transactionId , selectedEndpoint );
359+ parentChannel .recordAffinity (transactionId , selectedEndpoint , allowDefaultAffinity );
347360 }
348361
349362 void maybeClearAffinity () {
@@ -378,6 +391,12 @@ public void onMessage(ResponseT message) {
378391 if (transactionId != null ) {
379392 call .maybeRecordAffinity (transactionId );
380393 }
394+ } else if (message instanceof Transaction ) {
395+ Transaction response = (Transaction ) message ;
396+ ByteString transactionId = transactionIdFromTransaction (response );
397+ if (transactionId != null ) {
398+ call .maybeRecordAffinity (transactionId );
399+ }
381400 }
382401 super .onMessage (message );
383402 }
0 commit comments