Skip to content

Commit 3171897

Browse files
Restore Gemma4 screenshots and add LiteRT backend fallbacks
1 parent 02918f9 commit 3171897

2 files changed

Lines changed: 63 additions & 15 deletions

File tree

app/src/main/kotlin/com/google/ai/sample/GenerativeAiViewModelFactory.kt

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,6 @@ enum class ModelOption(
6060
"gemma-4-e4b-it",
6161
ApiProvider.GOOGLE,
6262
"https://huggingface.co/litert-community/gemma-4-E4B-it-litert-lm/resolve/main/gemma-4-E4B-it.litertlm?download=true",
63-
supportsScreenshot = false,
6463
isOfflineModel = true,
6564
offlineModelFilename = "gemma-4-E4B-it.litertlm"
6665
),

app/src/main/kotlin/com/google/ai/sample/feature/multimodal/PhotoReasoningViewModel.kt

Lines changed: 63 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -344,31 +344,23 @@ class PhotoReasoningViewModel(
344344
"modelPath=${modelFile.absolutePath}, modelSizeBytes=${modelFile.length()}"
345345
)
346346
if (liteRtEngine == null) {
347-
val liteRtBackend = if (backend == InferenceBackend.GPU) Backend.GPU() else Backend.CPU()
348-
val visionBackend = if (currentModel.supportsScreenshot) Backend.CPU() else null
347+
val preferredBackend = if (backend == InferenceBackend.GPU) Backend.GPU() else Backend.CPU()
348+
val preferredVisionBackend = if (currentModel.supportsScreenshot) Backend.GPU() else null
349349
val audioBackend = null
350350
val cacheDir =
351351
if (modelFile.absolutePath.startsWith("/data/local/tmp")) {
352352
context.getExternalFilesDir(null)?.absolutePath
353353
} else {
354354
null
355355
}
356-
val engineConfig = EngineConfig(
356+
liteRtEngine = createLiteRtEngineWithFallbacks(
357357
modelPath = modelFile.absolutePath,
358-
backend = liteRtBackend,
359-
visionBackend = visionBackend,
358+
preferredBackend = preferredBackend,
359+
preferredVisionBackend = preferredVisionBackend,
360360
audioBackend = audioBackend,
361-
maxNumTokens = null,
362361
cacheDir = cacheDir
363362
)
364-
Log.i(
365-
TAG,
366-
"Creating LiteRT engine with backend=$liteRtBackend, " +
367-
"visionBackend=$visionBackend, audioBackend=$audioBackend, " +
368-
"cacheDir=$cacheDir"
369-
)
370-
liteRtEngine = Engine(engineConfig).also { it.initialize() }
371-
Log.d(TAG, "Offline model initialized with LiteRT-LM Engine backend=$liteRtBackend")
363+
Log.d(TAG, "Offline model initialized with LiteRT-LM Engine")
372364
}
373365
} else {
374366
if (llmInference == null) {
@@ -421,6 +413,63 @@ class PhotoReasoningViewModel(
421413
val supportedAbis = Build.SUPPORTED_ABIS?.toSet().orEmpty()
422414
return supportedAbis.contains("arm64-v8a") || supportedAbis.contains("x86_64")
423415
}
416+
417+
private fun createLiteRtEngineWithFallbacks(
418+
modelPath: String,
419+
preferredBackend: Backend,
420+
preferredVisionBackend: Backend?,
421+
audioBackend: Backend?,
422+
cacheDir: String?
423+
): Engine {
424+
val cpuBackend = Backend.CPU()
425+
val gpuBackend = Backend.GPU()
426+
val attempts = linkedSetOf(
427+
preferredBackend to preferredVisionBackend,
428+
cpuBackend to preferredVisionBackend,
429+
cpuBackend to cpuBackend,
430+
gpuBackend to cpuBackend
431+
)
432+
var lastError: Exception? = null
433+
val failureDetails = StringBuilder()
434+
435+
attempts.forEachIndexed { index, (backendAttempt, visionAttempt) ->
436+
try {
437+
Log.i(
438+
TAG,
439+
"LiteRT init attempt ${index + 1}/${attempts.size}: " +
440+
"backend=$backendAttempt visionBackend=$visionAttempt audioBackend=$audioBackend cacheDir=$cacheDir"
441+
)
442+
val config = EngineConfig(
443+
modelPath = modelPath,
444+
backend = backendAttempt,
445+
visionBackend = visionAttempt,
446+
audioBackend = audioBackend,
447+
maxNumTokens = null,
448+
cacheDir = cacheDir
449+
)
450+
return Engine(config).also { it.initialize() }
451+
} catch (e: Exception) {
452+
lastError = e
453+
val msg = e.message ?: e.toString()
454+
failureDetails
455+
.append("Attempt ")
456+
.append(index + 1)
457+
.append(" failed (backend=")
458+
.append(backendAttempt)
459+
.append(", visionBackend=")
460+
.append(visionAttempt)
461+
.append("): ")
462+
.append(msg)
463+
.append('\n')
464+
Log.w(TAG, "LiteRT init attempt ${index + 1} failed", e)
465+
}
466+
}
467+
468+
throw IllegalStateException(
469+
"All LiteRT initialization attempts failed.\n$failureDetails",
470+
lastError
471+
)
472+
}
424473

425474
fun reinitializeOfflineModel(context: Context) {
426475
viewModelScope.launch(Dispatchers.IO) {

0 commit comments

Comments
 (0)