Skip to content
Merged
18 changes: 18 additions & 0 deletions app/src/main/java/to/bitkit/repositories/LightningRepo.kt
Original file line number Diff line number Diff line change
Expand Up @@ -262,6 +262,7 @@ class LightningRepo @Inject constructor(
customRgsServerUrl: String? = null,
eventHandler: NodeEventHandler? = null,
channelMigration: ChannelDataMigration? = null,
shouldValidateGraph: Boolean = true,
): Result<Unit> = withContext(bgDispatcher) {
if (_isRecoveryMode.value) {
return@withContext Result.failure(RecoveryModeError())
Expand Down Expand Up @@ -313,6 +314,23 @@ class LightningRepo @Inject constructor(
updateGeoBlockState()
refreshChannelCache()

// Validate network graph has trusted peers (RGS cache can become stale)
if (shouldValidateGraph && !lightningService.validateNetworkGraph()) {
Logger.warn("Network graph is stale, resetting and restarting...", context = TAG)
lightningService.stop()
lightningService.resetNetworkGraph(walletIndex)
return@withContext start(
walletIndex = walletIndex,
timeout = timeout,
shouldRetry = shouldRetry,
customServerUrl = customServerUrl,
customRgsServerUrl = customRgsServerUrl,
eventHandler = eventHandler,
channelMigration = channelMigration,
shouldValidateGraph = false, // Prevent infinite loop
)
Comment thread
jvsena42 marked this conversation as resolved.
Outdated
}

// Post-startup tasks (non-blocking)
connectToTrustedPeers().onFailure {
Logger.error("Failed to connect to trusted peers", it, context = TAG)
Expand Down
43 changes: 43 additions & 0 deletions app/src/main/java/to/bitkit/services/LightningService.kt
Original file line number Diff line number Diff line change
Expand Up @@ -262,6 +262,49 @@ class LightningService @Inject constructor(
Logger.info("LDK storage wiped", context = TAG)
}

/**
* Resets the network graph cache, forcing a full RGS sync on next startup.
* This is useful when the cached graph is stale or missing nodes.
* Note: Node must be stopped before calling this.
*/
fun resetNetworkGraph(walletIndex: Int) {
if (node != null) throw ServiceError.NodeStillRunning()
Logger.warn("Resetting network graph cache…", context = TAG)
val ldkPath = Path(Env.ldkStoragePath(walletIndex)).toFile()
val graphFile = ldkPath.resolve("network_graph")
if (graphFile.exists()) {
graphFile.delete()
Logger.info("Network graph cache deleted", context = TAG)
} else {
Logger.info("No network graph cache found", context = TAG)
}
}

/**
* Validates that all trusted peers are present in the network graph.
* Returns false if any trusted peer is missing, indicating the graph cache is stale.
*/
fun validateNetworkGraph(): Boolean {
val node = this.node ?: return true
val graph = node.networkGraph()
val graphNodes = graph.listNodes().toSet()
if (graphNodes.isEmpty()) {
Logger.debug("Network graph is empty, skipping validation", context = TAG)
return true
Comment thread
jvsena42 marked this conversation as resolved.
}
val missingPeers = trustedPeers.filter { it.nodeId !in graphNodes }
if (missingPeers.isNotEmpty()) {
Logger.warn(
"Network graph missing ${missingPeers.size} trusted peers: " +
missingPeers.joinToString { it.nodeId.take(20) + "..." },
context = TAG,
)
return false
}
Logger.debug("Network graph validated: all ${trustedPeers.size} trusted peers present", context = TAG)
return true
}

suspend fun sync() {
val node = this.node ?: throw ServiceError.NodeNotSetup()

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,7 @@ class LightningNodeServiceTest : BaseUnitTest() {
anyOrNull(),
anyOrNull(),
anyOrNull(),
any(),
)
} doAnswer {
capturedHandler = it.getArgument(5) as? NodeEventHandler
Expand Down
6 changes: 6 additions & 0 deletions app/src/test/java/to/bitkit/repositories/LightningRepoTest.kt
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@ class LightningRepoTest : BaseUnitTest() {
whenever(lightningService.setup(any(), anyOrNull(), anyOrNull(), anyOrNull(), anyOrNull())).thenReturn(Unit)
whenever(lightningService.start(anyOrNull(), any())).thenReturn(Unit)
whenever(lightningService.sync()).thenReturn(Unit)
whenever(lightningService.validateNetworkGraph()).thenReturn(true)
whenever(settingsStore.data).thenReturn(flowOf(SettingsData()))
val blocktank = mock<BlocktankService>()
whenever(coreService.blocktank).thenReturn(blocktank)
Expand All @@ -107,6 +108,7 @@ class LightningRepoTest : BaseUnitTest() {
whenever(lightningService.node).thenReturn(mock())
whenever(lightningService.setup(any(), anyOrNull(), anyOrNull(), anyOrNull(), anyOrNull())).thenReturn(Unit)
whenever(lightningService.start(anyOrNull(), any())).thenReturn(Unit)
whenever(lightningService.validateNetworkGraph()).thenReturn(true)
val blocktank = mock<BlocktankService>()
whenever(coreService.blocktank).thenReturn(blocktank)
whenever(blocktank.info(any())).thenReturn(null)
Expand Down Expand Up @@ -388,6 +390,7 @@ class LightningRepoTest : BaseUnitTest() {
whenever(lightningService.node).thenReturn(mock())
whenever(lightningService.setup(any(), anyOrNull(), anyOrNull(), anyOrNull(), anyOrNull())).thenReturn(Unit)
whenever(lightningService.start(anyOrNull(), any())).thenReturn(Unit)
whenever(lightningService.validateNetworkGraph()).thenReturn(true)
whenever(lightningService.sync()).thenThrow(RuntimeException("Sync failed"))
whenever(settingsStore.data).thenReturn(flowOf(SettingsData()))
val blocktank = mock<BlocktankService>()
Expand Down Expand Up @@ -621,6 +624,7 @@ class LightningRepoTest : BaseUnitTest() {
whenever(lightningService.node).thenReturn(null)
whenever(lightningService.setup(any(), anyOrNull(), anyOrNull(), anyOrNull(), anyOrNull())).thenReturn(Unit)
whenever(lightningService.start(anyOrNull(), any())).thenReturn(Unit)
whenever(lightningService.validateNetworkGraph()).thenReturn(true)
whenever(settingsStore.data).thenReturn(flowOf(SettingsData()))

val blocktank = mock<BlocktankService>()
Expand Down Expand Up @@ -665,6 +669,7 @@ class LightningRepoTest : BaseUnitTest() {
whenever(lightningService.node).thenReturn(null)
whenever(lightningService.setup(any(), anyOrNull(), anyOrNull(), anyOrNull(), anyOrNull())).thenReturn(Unit)
whenever(lightningService.start(anyOrNull(), any())).thenReturn(Unit)
whenever(lightningService.validateNetworkGraph()).thenReturn(true)
whenever(settingsStore.data).thenReturn(flowOf(SettingsData()))

val blocktank = mock<BlocktankService>()
Expand All @@ -690,6 +695,7 @@ class LightningRepoTest : BaseUnitTest() {

// lightningService.start() succeeds (state becomes Running at line 241)
whenever(lightningService.start(anyOrNull(), any())).thenReturn(Unit)
whenever(lightningService.validateNetworkGraph()).thenReturn(true)
// lightningService.nodeId throws during syncState() (called at line 244, AFTER state = Running)
whenever(lightningService.nodeId).thenThrow(RuntimeException("error during syncState"))

Expand Down
39 changes: 34 additions & 5 deletions app/src/test/java/to/bitkit/ui/WalletViewModelTest.kt
Original file line number Diff line number Diff line change
Expand Up @@ -241,8 +241,18 @@ class WalletViewModelTest : BaseUnitTest() {
whenever(testWalletRepo.walletExists()).thenReturn(true)
whenever(testLightningRepo.lightningState).thenReturn(lightningState)
whenever(testLightningRepo.isRecoveryMode).thenReturn(isRecoveryMode)
whenever(testLightningRepo.start(any(), anyOrNull(), any(), anyOrNull(), anyOrNull(), anyOrNull(), anyOrNull()))
.thenReturn(Result.success(Unit))
whenever(
testLightningRepo.start(
any(),
anyOrNull(),
any(),
anyOrNull(),
anyOrNull(),
anyOrNull(),
anyOrNull(),
any(),
),
).thenReturn(Result.success(Unit))

val testSut = WalletViewModel(
context = context,
Expand All @@ -262,7 +272,16 @@ class WalletViewModelTest : BaseUnitTest() {
testSut.start()
advanceUntilIdle()

verify(testLightningRepo).start(any(), anyOrNull(), any(), anyOrNull(), anyOrNull(), anyOrNull(), anyOrNull())
verify(testLightningRepo).start(
any(),
anyOrNull(),
any(),
anyOrNull(),
anyOrNull(),
anyOrNull(),
anyOrNull(),
any(),
)
verify(testWalletRepo).refreshBip21()
}

Expand All @@ -282,8 +301,18 @@ class WalletViewModelTest : BaseUnitTest() {
whenever(testWalletRepo.restoreWallet(any(), anyOrNull())).thenReturn(Result.success(Unit))
whenever(testLightningRepo.lightningState).thenReturn(lightningState)
whenever(testLightningRepo.isRecoveryMode).thenReturn(isRecoveryMode)
whenever(testLightningRepo.start(any(), anyOrNull(), any(), anyOrNull(), anyOrNull(), anyOrNull(), anyOrNull()))
.thenReturn(Result.success(Unit))
whenever(
testLightningRepo.start(
any(),
anyOrNull(),
any(),
anyOrNull(),
anyOrNull(),
anyOrNull(),
anyOrNull(),
any(),
),
).thenReturn(Result.success(Unit))

val testSut = WalletViewModel(
context = context,
Expand Down
Loading