1717package com .google .cloud .spanner .spi .v1 ;
1818
1919import static com .google .common .truth .Truth .assertThat ;
20+ import static org .junit .Assert .assertEquals ;
21+ import static org .junit .Assert .assertFalse ;
22+ import static org .junit .Assert .assertNotNull ;
2023import static org .junit .Assert .assertThrows ;
2124
2225import com .google .api .gax .grpc .InstantiatingGrpcChannelProvider ;
@@ -296,11 +299,11 @@ public void beginTransactionWithMutationKeyAddsRoutingHint() throws Exception {
296299 (RecordingClientCall <BeginTransactionRequest , Transaction >)
297300 harness .defaultManagedChannel .latestCall ();
298301
299- assertThat (beginDelegate .lastMessage ). isNotNull ( );
300- assertThat ( beginDelegate .lastMessage .getRoutingHint ().getDatabaseId ()). isEqualTo ( 7L );
301- assertThat ( beginDelegate . lastMessage . getRoutingHint (). getSchemaGeneration (). toStringUtf8 ())
302- . isEqualTo ( "1" );
303- assertThat (beginDelegate .lastMessage .getRoutingHint ().getKey ().isEmpty ()). isFalse ( );
302+ assertNotNull (beginDelegate .lastMessage );
303+ assertEquals ( 7L , beginDelegate .lastMessage .getRoutingHint ().getDatabaseId ());
304+ assertEquals (
305+ "1" , beginDelegate . lastMessage . getRoutingHint (). getSchemaGeneration (). toStringUtf8 () );
306+ assertFalse (beginDelegate .lastMessage .getRoutingHint ().getKey ().isEmpty ());
304307 }
305308
306309 @ Test
@@ -339,11 +342,66 @@ public void transactionCacheUpdateEnablesCommitRoutingHint() throws Exception {
339342 (RecordingClientCall <CommitRequest , CommitResponse >)
340343 harness .defaultManagedChannel .latestCall ();
341344
342- assertThat (commitDelegate .lastMessage ).isNotNull ();
343- assertThat (commitDelegate .lastMessage .getRoutingHint ().getDatabaseId ()).isEqualTo (7L );
344- assertThat (commitDelegate .lastMessage .getRoutingHint ().getSchemaGeneration ().toStringUtf8 ())
345- .isEqualTo ("1" );
346- assertThat (commitDelegate .lastMessage .getRoutingHint ().getKey ().isEmpty ()).isFalse ();
345+ assertNotNull (commitDelegate .lastMessage );
346+ assertEquals (7L , commitDelegate .lastMessage .getRoutingHint ().getDatabaseId ());
347+ assertEquals (
348+ "1" , commitDelegate .lastMessage .getRoutingHint ().getSchemaGeneration ().toStringUtf8 ());
349+ assertFalse (commitDelegate .lastMessage .getRoutingHint ().getKey ().isEmpty ());
350+ }
351+
352+ @ Test
353+ public void singleUseCommitWithMutationsRoutesUsingRoutingHint () throws Exception {
354+ TestHarness harness = createHarness ();
355+ seedCache (harness , createMutationRecipeCacheUpdate ());
356+
357+ ClientCall <CommitRequest , CommitResponse > firstCommitCall =
358+ harness .channel .newCall (SpannerGrpc .getCommitMethod (), CallOptions .DEFAULT );
359+ firstCommitCall .start (new CapturingListener <CommitResponse >(), new Metadata ());
360+ firstCommitCall .sendMessage (
361+ CommitRequest .newBuilder ()
362+ .setSession (SESSION )
363+ .setSingleUseTransaction (
364+ TransactionOptions .newBuilder ()
365+ .setReadWrite (TransactionOptions .ReadWrite .getDefaultInstance ()))
366+ .addMutations (createInsertMutation ("b" ))
367+ .build ());
368+
369+ @ SuppressWarnings ("unchecked" )
370+ RecordingClientCall <CommitRequest , CommitResponse > firstCommitDelegate =
371+ (RecordingClientCall <CommitRequest , CommitResponse >)
372+ harness .defaultManagedChannel .latestCall ();
373+
374+ assertNotNull (firstCommitDelegate .lastMessage );
375+ RoutingHint routingHint = firstCommitDelegate .lastMessage .getRoutingHint ();
376+ assertFalse (routingHint .getKey ().isEmpty ());
377+
378+ seedCache (harness , createRangeCacheUpdateForHint (routingHint ));
379+
380+ ClientCall <CommitRequest , CommitResponse > secondCommitCall =
381+ harness .channel .newCall (SpannerGrpc .getCommitMethod (), CallOptions .DEFAULT );
382+ secondCommitCall .start (new CapturingListener <CommitResponse >(), new Metadata ());
383+ secondCommitCall .sendMessage (
384+ CommitRequest .newBuilder ()
385+ .setSession (SESSION )
386+ .setSingleUseTransaction (
387+ TransactionOptions .newBuilder ()
388+ .setReadWrite (TransactionOptions .ReadWrite .getDefaultInstance ()))
389+ .addMutations (createInsertMutation ("b" ))
390+ .build ());
391+
392+ assertThat (harness .endpointCache .callCountForAddress (DEFAULT_ADDRESS )).isEqualTo (3 );
393+ assertThat (harness .endpointCache .callCountForAddress ("server-a:1234" )).isEqualTo (1 );
394+
395+ @ SuppressWarnings ("unchecked" )
396+ RecordingClientCall <CommitRequest , CommitResponse > commitDelegate =
397+ (RecordingClientCall <CommitRequest , CommitResponse >)
398+ harness .endpointCache .latestCallForAddress ("server-a:1234" );
399+
400+ assertNotNull (commitDelegate .lastMessage );
401+ assertEquals (7L , commitDelegate .lastMessage .getRoutingHint ().getDatabaseId ());
402+ assertEquals (
403+ "1" , commitDelegate .lastMessage .getRoutingHint ().getSchemaGeneration ().toStringUtf8 ());
404+ assertFalse (commitDelegate .lastMessage .getRoutingHint ().getKey ().isEmpty ());
347405 }
348406
349407 @ Test
@@ -389,12 +447,11 @@ public void commitResponseCacheUpdateEnablesSubsequentBeginRoutingHint() throws
389447 (RecordingClientCall <BeginTransactionRequest , Transaction >)
390448 harness .defaultManagedChannel .latestCall ();
391449
392- assertThat (routedBeginDelegate .lastMessage ).isNotNull ();
393- assertThat (routedBeginDelegate .lastMessage .getRoutingHint ().getDatabaseId ()).isEqualTo (7L );
394- assertThat (
395- routedBeginDelegate .lastMessage .getRoutingHint ().getSchemaGeneration ().toStringUtf8 ())
396- .isEqualTo ("1" );
397- assertThat (routedBeginDelegate .lastMessage .getRoutingHint ().getKey ().isEmpty ()).isFalse ();
450+ assertNotNull (routedBeginDelegate .lastMessage );
451+ assertEquals (7L , routedBeginDelegate .lastMessage .getRoutingHint ().getDatabaseId ());
452+ assertEquals (
453+ "1" , routedBeginDelegate .lastMessage .getRoutingHint ().getSchemaGeneration ().toStringUtf8 ());
454+ assertFalse (routedBeginDelegate .lastMessage .getRoutingHint ().getKey ().isEmpty ());
398455 }
399456
400457 @ Test
@@ -759,6 +816,13 @@ private static CacheUpdate createTwoRangeCacheUpdate() {
759816 }
760817
761818 private static CacheUpdate createMutationRoutingCacheUpdate () throws TextFormat .ParseException {
819+ return createMutationRecipeCacheUpdate ().toBuilder ()
820+ .mergeFrom (
821+ createRangeCacheUpdateForHint (RoutingHint .newBuilder ().setKey (bytes ("a" )).build ()))
822+ .build ();
823+ }
824+
825+ private static CacheUpdate createMutationRecipeCacheUpdate () throws TextFormat .ParseException {
762826 RecipeList keyRecipes =
763827 parseRecipeList (
764828 "schema_generation: \" 1\" \n "
@@ -772,13 +836,21 @@ private static CacheUpdate createMutationRoutingCacheUpdate() throws TextFormat.
772836 + " identifier: \" k\" \n "
773837 + " }\n "
774838 + "}\n " );
839+ return CacheUpdate .newBuilder ().setDatabaseId (7L ).setKeyRecipes (keyRecipes ).build ();
840+ }
841+
842+ private static CacheUpdate createRangeCacheUpdateForHint (RoutingHint hint ) {
843+ ByteString key = hint .getKey ();
844+ ByteString limitKey =
845+ hint .getLimitKey ().isEmpty ()
846+ ? key .concat (ByteString .copyFrom (new byte [] {0 }))
847+ : hint .getLimitKey ();
775848 return CacheUpdate .newBuilder ()
776849 .setDatabaseId (7L )
777- .setKeyRecipes (keyRecipes )
778850 .addRange (
779851 Range .newBuilder ()
780- .setStartKey (bytes ( "a" ) )
781- .setLimitKey (bytes ( "m" ) )
852+ .setStartKey (key )
853+ .setLimitKey (limitKey )
782854 .setGroupUid (1L )
783855 .setSplitId (1L )
784856 .setGeneration (bytes ("1" )))
0 commit comments