Skip to content
Merged
Show file tree
Hide file tree
Changes from 11 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions app/src/main/java/to/bitkit/di/HttpModule.kt
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,14 @@ import io.ktor.client.plugins.defaultRequest
import io.ktor.client.plugins.logging.LogLevel
import io.ktor.client.plugins.logging.Logging
import io.ktor.client.plugins.logging.LoggingConfig
import io.ktor.client.request.head
import io.ktor.http.ContentType
import io.ktor.http.contentType
import io.ktor.http.isSuccess
import io.ktor.serialization.kotlinx.json.json
import kotlinx.serialization.json.Json
import to.bitkit.utils.UrlValidator
import to.bitkit.utils.AppError
import to.bitkit.utils.Logger
import javax.inject.Singleton
import io.ktor.client.plugins.logging.Logger as KtorLogger
Expand Down Expand Up @@ -43,6 +47,17 @@ object HttpModule {
}
}

@Provides
@Singleton
fun provideUrlValidator(httpClient: HttpClient) = UrlValidator { url ->
runCatching {
val response = httpClient.head(url)
if (!response.status.isSuccess()) {
throw AppError("Server returned '${response.status}'")
}
}
}

@Suppress("MagicNumber")
private fun HttpTimeoutConfig.defaultTimeoutConfig() {
requestTimeoutMillis = 60_000
Expand Down
12 changes: 12 additions & 0 deletions app/src/main/java/to/bitkit/repositories/LightningRepo.kt
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ import to.bitkit.services.NodeEventHandler
import to.bitkit.utils.AppError
import to.bitkit.utils.Logger
import to.bitkit.utils.ServiceError
import to.bitkit.utils.UrlValidator
import java.io.File
import java.util.concurrent.ConcurrentHashMap
import java.util.concurrent.atomic.AtomicBoolean
Expand All @@ -105,6 +106,7 @@ class LightningRepo @Inject constructor(
private val preActivityMetadataRepo: PreActivityMetadataRepo,
private val connectivityRepo: ConnectivityRepo,
private val vssBackupClientLdk: VssBackupClientLdk,
private val urlValidator: UrlValidator,
) {
private val _lightningState = MutableStateFlow(LightningState())
val lightningState = _lightningState.asStateFlow()
Expand Down Expand Up @@ -619,6 +621,8 @@ class LightningRepo @Inject constructor(
suspend fun restartWithRgsServer(newRgsUrl: String): Result<Unit> = withContext(bgDispatcher) {
Logger.info("Changing ldk-node RGS server to: '$newRgsUrl'", context = TAG)

validateRgsUrl(newRgsUrl).onFailure { return@withContext Result.failure(it) }

waitForNodeToStop().onFailure { return@withContext Result.failure(it) }
stop().onFailure {
Logger.error("Failed to stop node during RGS server change", it, context = TAG)
Expand All @@ -640,6 +644,14 @@ class LightningRepo @Inject constructor(
}
}

private suspend fun validateRgsUrl(url: String): Result<Unit> = withContext(bgDispatcher) {
val initialTimestamp = 0
val testUrl = "${url.trimEnd('/')}/$initialTimestamp"
urlValidator.validate(testUrl).onFailure {
Logger.warn("RGS server unreachable at '$testUrl'", it, context = TAG)
Comment thread
jvsena42 marked this conversation as resolved.
Outdated
}
}

suspend fun getBalanceForAddressType(addressType: AddressType): Result<ULong> = withContext(bgDispatcher) {
executeWhenNodeRunning("getBalanceForAddressType") {
runCatching {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package to.bitkit.ui.settings.advanced

import androidx.compose.runtime.Stable
import androidx.lifecycle.ViewModel
import androidx.lifecycle.viewModelScope
import dagger.hilt.android.lifecycle.HiltViewModel
Expand All @@ -15,6 +16,7 @@ import to.bitkit.data.SettingsStore
import to.bitkit.di.BgDispatcher
import to.bitkit.env.Env
import to.bitkit.repositories.LightningRepo
import java.net.URI
import javax.inject.Inject

@HiltViewModel
Expand All @@ -24,6 +26,14 @@ class RgsServerViewModel @Inject constructor(
private val lightningRepo: LightningRepo,
) : ViewModel() {

companion object {
private val HOSTNAME_PATTERN = Regex(
"^([a-z\\d]([a-z\\d-]*[a-z\\d])*\\.)+[a-z]{2,}|(\\d{1,3}\\.){3}\\d{1,3}$",
RegexOption.IGNORE_CASE,
)
private val PATH_PATTERN = Regex("^(/[a-zA-Z\\d_.~%+-]*)*$")
}

private val _uiState = MutableStateFlow(RgsServerUiState())
val uiState: StateFlow<RgsServerUiState> = _uiState.asStateFlow()

Expand Down Expand Up @@ -110,23 +120,29 @@ class RgsServerViewModel @Inject constructor(
}

private fun isValidURL(data: String): Boolean {
val pattern = Regex(
"^(https?://)?" + // protocol
"((([a-z\\d]([a-z\\d-]*[a-z\\d])*)\\.)+[a-z]{2,}|" + // domain name
"((\\d{1,3}\\.){3}\\d{1,3}))" + // IP (v4) address
"(:\\d+)?(/[-a-z\\d%_.~+]*)*", // port and path
RegexOption.IGNORE_CASE
)

// Allow localhost in development mode
if (Env.isDebug && data.contains("localhost")) {
return true
val normalized = if (!data.startsWith("http://") && !data.startsWith("https://")) {
"https://$data"
} else {
data
}

return pattern.matches(data)
return try {
val uri = URI(normalized)
val hostname = uri.host ?: return false

if (Env.isDebug && hostname == "localhost") return true

if (!HOSTNAME_PATTERN.matches(hostname)) return false

val path = uri.path.orEmpty()
path.isEmpty() || PATH_PATTERN.matches(path)
} catch (_: Throwable) {
false
}
}
}

@Stable
data class RgsServerUiState(
val connectedRgsUrl: String? = null,
val rgsUrl: String = "",
Expand Down
5 changes: 5 additions & 0 deletions app/src/main/java/to/bitkit/utils/UrlValidator.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
package to.bitkit.utils

fun interface UrlValidator {
suspend fun validate(url: String): Result<Unit>
}
47 changes: 0 additions & 47 deletions app/src/main/java/to/bitkit/viewmodels/WalletViewModel.kt
Original file line number Diff line number Diff line change
Expand Up @@ -328,53 +328,6 @@ class WalletViewModel @Inject constructor(
}
}

private suspend fun checkForOrphanedChannelMonitorRecovery() {
if (migrationService.isChannelRecoveryChecked()) return

Logger.info("Running one-time channel monitor recovery check", context = TAG)

val allMonitorsRetrieved = runCatching {
val allRetrieved = migrationService.fetchRNRemoteLdkData()
// don't overwrite channel manager, we only need the monitors for the sweep
val channelMigration = buildChannelMigrationIfAvailable()?.let {
ChannelDataMigration(channelManager = null, channelMonitors = it.channelMonitors)
}

if (channelMigration == null) {
Logger.info("No channel monitors found on RN backup", context = TAG)
return@runCatching allRetrieved
}

Logger.info(
"Found ${channelMigration.channelMonitors.size} monitors on RN backup, attempting recovery",
context = TAG,
)

lightningRepo.stop().onFailure {
Logger.error("Failed to stop node for channel recovery", it, context = TAG)
}
delay(CHANNEL_RECOVERY_RESTART_DELAY_MS)
lightningRepo.start(channelMigration = channelMigration, shouldRetry = false)
.onSuccess {
migrationService.consumePendingChannelMigration()
walletRepo.syncNodeAndWallet()
walletRepo.syncBalances()
Logger.info("Channel monitor recovery complete", context = TAG)
}
.onFailure {
Logger.error("Failed to restart node after channel recovery", it, context = TAG)
}

allRetrieved
}.getOrDefault(false)

if (allMonitorsRetrieved) {
migrationService.markChannelRecoveryChecked()
} else {
Logger.warn("Some monitors failed to download, will retry on next startup", context = TAG)
}
}

fun stop() {
if (!walletExists) return

Expand Down
75 changes: 75 additions & 0 deletions app/src/test/java/to/bitkit/repositories/LightningRepoTest.kt
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ import to.bitkit.services.LightningService
import to.bitkit.services.LnurlService
import to.bitkit.services.LspNotificationsService
import to.bitkit.test.BaseUnitTest
import to.bitkit.utils.UrlValidator
import kotlin.test.assertEquals
import kotlin.test.assertFalse
import kotlin.test.assertNotNull
Expand All @@ -72,6 +73,7 @@ class LightningRepoTest : BaseUnitTest() {
private val lnurlService = mock<LnurlService>()
private val connectivityRepo = mock<ConnectivityRepo>()
private val vssBackupClientLdk = mock<VssBackupClientLdk>()
private val urlValidator = UrlValidator { Result.success(Unit) }

@Before
fun setUp() = runBlocking {
Expand All @@ -94,6 +96,7 @@ class LightningRepoTest : BaseUnitTest() {
preActivityMetadataRepo = preActivityMetadataRepo,
connectivityRepo = connectivityRepo,
vssBackupClientLdk = vssBackupClientLdk,
urlValidator = urlValidator,
)
}

Expand Down Expand Up @@ -498,6 +501,78 @@ class LightningRepoTest : BaseUnitTest() {
assertTrue(result.isFailure)
}

@Test
fun `restartWithRgsServer should setup with new rgs server`() = test {
startNodeForTesting()
val customRgsUrl = "https://rgs.example.com/snapshot"
whenever(lightningService.node).thenReturn(null)
whenever(lightningService.stop()).thenReturn(Unit)

val result = sut.restartWithRgsServer(customRgsUrl)

assertTrue(result.isSuccess)
val inOrder = inOrder(lightningService)
inOrder.verify(lightningService).stop()
inOrder.verify(lightningService).setup(any(), isNull(), eq(customRgsUrl), anyOrNull(), anyOrNull())
inOrder.verify(lightningService).start(anyOrNull(), any())
assertEquals(NodeLifecycleState.Running, sut.lightningState.value.nodeLifecycleState)
}

@Test
fun `restartWithRgsServer should handle stop failure`() = test {
startNodeForTesting()
whenever(lightningService.stop()).thenThrow(RuntimeException("Stop failed"))

val result = sut.restartWithRgsServer("https://rgs.example.com/snapshot")

assertTrue(result.isFailure)
}

@Test
fun `restartWithRgsServer should handle start failure and recover`() = test {
startNodeForTesting()
whenever(lightningService.node).thenReturn(null)
whenever(lightningService.stop()).thenReturn(Unit)
whenever(lightningService.setup(any(), isNull(), eq("https://bad.rgs/snapshot"), anyOrNull(), anyOrNull()))
.thenThrow(RuntimeException("Failed to start node"))

val result = sut.restartWithRgsServer("https://bad.rgs/snapshot")

assertTrue(result.isFailure)
}

@Test
fun `restartWithRgsServer should fail when url is unreachable`() = test {
val failingValidator = UrlValidator { Result.failure(Exception("DNS resolution failed")) }
val sutWithFailingValidator = LightningRepo(
bgDispatcher = testDispatcher,
lightningService = lightningService,
settingsStore = settingsStore,
coreService = coreService,
lspNotificationsService = lspNotificationsService,
firebaseMessaging = firebaseMessaging,
keychain = keychain,
lnurlService = lnurlService,
cacheStore = cacheStore,
preActivityMetadataRepo = preActivityMetadataRepo,
connectivityRepo = connectivityRepo,
vssBackupClientLdk = vssBackupClientLdk,
urlValidator = failingValidator,
)
sutWithFailingValidator.setInitNodeLifecycleState()
whenever(lightningService.node).thenReturn(mock())
whenever(lightningService.sync()).thenReturn(Unit)
val blocktank = mock<BlocktankService>()
whenever(coreService.blocktank).thenReturn(blocktank)
whenever(blocktank.info(any())).thenReturn(null)
sutWithFailingValidator.start()

val result = sutWithFailingValidator.restartWithRgsServer("https://rapidsync.lightningdevkit/snapshot")

assertTrue(result.isFailure)
assertEquals("DNS resolution failed", result.exceptionOrNull()?.message)
}

@Test
fun `getFeeRateForSpeed should use provided feeRates`() = test {
val mockFeeRates = mock<FeeRates>()
Expand Down
Loading
Loading