From 6f68bf05d61c177c9bc4faddf64e82994705a1c7 Mon Sep 17 00:00:00 2001 From: Pierre-Yves Nicolas <6371790+pynicolas@users.noreply.github.com> Date: Sun, 25 May 2025 22:26:22 +0200 Subject: [PATCH] Use LiteRT to analyze an image: display a trivial message as a result --- .gitignore | 1 + .idea/deploymentTargetSelector.xml | 3 + app/build.gradle.kts | 5 + .../myscan/ImageSegmentationService.kt | 207 ++++++++++++++++++ .../java/org/mydomain/myscan/MainActivity.kt | 44 ++-- .../java/org/mydomain/myscan/MainViewModel.kt | 56 +++++ .../main/java/org/mydomain/myscan/UiState.kt | 10 + gradle/libs.versions.toml | 6 + 8 files changed, 319 insertions(+), 13 deletions(-) create mode 100644 app/src/main/java/org/mydomain/myscan/ImageSegmentationService.kt create mode 100644 app/src/main/java/org/mydomain/myscan/MainViewModel.kt create mode 100644 app/src/main/java/org/mydomain/myscan/UiState.kt diff --git a/.gitignore b/.gitignore index aa724b7..dcbd24a 100644 --- a/.gitignore +++ b/.gitignore @@ -13,3 +13,4 @@ .externalNativeBuild .cxx local.properties +*.tflite diff --git a/.idea/deploymentTargetSelector.xml b/.idea/deploymentTargetSelector.xml index 762faed..652219b 100644 --- a/.idea/deploymentTargetSelector.xml +++ b/.idea/deploymentTargetSelector.xml @@ -13,6 +13,9 @@ + + \ No newline at end of file diff --git a/app/build.gradle.kts b/app/build.gradle.kts index e462101..7eac30c 100644 --- a/app/build.gradle.kts +++ b/app/build.gradle.kts @@ -43,6 +43,8 @@ dependencies { implementation(libs.androidx.core.ktx) implementation(libs.androidx.lifecycle.runtime.ktx) + implementation(libs.androidx.lifecycle.runtime.compose) + implementation(libs.androidx.lifecycle.viewmodel.compose) implementation(libs.androidx.activity.compose) implementation(platform(libs.androidx.compose.bom)) implementation(libs.androidx.ui) @@ -53,6 +55,9 @@ dependencies { implementation(libs.androidx.camera.camera2) implementation(libs.androidx.camera.lifecycle) implementation(libs.androidx.camera.view) + implementation(libs.litert) + implementation(libs.litert.support) + implementation(libs.litert.metadata) testImplementation(libs.junit) androidTestImplementation(libs.androidx.junit) diff --git a/app/src/main/java/org/mydomain/myscan/ImageSegmentationService.kt b/app/src/main/java/org/mydomain/myscan/ImageSegmentationService.kt new file mode 100644 index 0000000..6b7d59c --- /dev/null +++ b/app/src/main/java/org/mydomain/myscan/ImageSegmentationService.kt @@ -0,0 +1,207 @@ +package org.mydomain.myscan + +import android.content.Context +import android.graphics.Bitmap +import android.graphics.Color +import android.os.SystemClock +import android.util.Log +import kotlinx.coroutines.Dispatchers +import kotlinx.coroutines.channels.BufferOverflow +import kotlinx.coroutines.flow.MutableSharedFlow +import kotlinx.coroutines.flow.MutableStateFlow +import kotlinx.coroutines.flow.SharedFlow +import kotlinx.coroutines.flow.StateFlow +import kotlinx.coroutines.flow.asStateFlow +import kotlinx.coroutines.isActive +import kotlinx.coroutines.withContext +import org.tensorflow.lite.Interpreter +import org.tensorflow.lite.support.common.FileUtil +import org.tensorflow.lite.support.common.ops.NormalizeOp +import org.tensorflow.lite.support.image.ColorSpaceType +import org.tensorflow.lite.support.image.ImageProcessor +import org.tensorflow.lite.support.image.ImageProperties +import org.tensorflow.lite.support.image.TensorImage +import org.tensorflow.lite.support.image.ops.ResizeOp +import org.tensorflow.lite.support.image.ops.Rot90Op +import java.nio.ByteBuffer +import java.nio.FloatBuffer +import java.util.Random + +// TODO Review and remove unneeded code +class ImageSegmentationService(private val context: Context) { + + companion object { + private const val TAG = "ImageSegmentation" + } + + private val _segmentation = MutableStateFlow(null) + val segmentation: StateFlow = _segmentation.asStateFlow() + + val error: SharedFlow + get() = _error + private val _error = MutableSharedFlow() + + private var interpreter: Interpreter? = null + + private val coloredLabels: List = coloredLabels() + + suspend fun initialize() { + interpreter = try { + val litertBuffer = FileUtil.loadMappedFile(context, "deeplab_v3.tflite") + Log.i(TAG, "Loaded LiteRT model") + val options = Interpreter.Options().apply { + numThreads = 2 + } + Interpreter(litertBuffer, options) + } catch (e: Exception) { + Log.i(TAG, "Failed to load LiteRT model: ${e.message}") + _error.emit(e) + null + } + } + + suspend fun runSegmentation(bitmap: Bitmap, rotationDegrees: Int) { + try { + withContext(Dispatchers.IO) { + if (interpreter == null) return@withContext + val startTime = SystemClock.uptimeMillis() + + val rotation = -rotationDegrees / 90 + val (_, h, w, _) = interpreter?.getOutputTensor(0)?.shape() ?: return@withContext + val imageProcessor = + ImageProcessor + .Builder() + .add(ResizeOp(h, w, ResizeOp.ResizeMethod.BILINEAR)) + .add(Rot90Op(rotation)) + .add(NormalizeOp(127.5f, 127.5f)) + .build() + + // Preprocess the image and convert it into a TensorImage for segmentation. + val tensorImage = imageProcessor.process(TensorImage.fromBitmap(bitmap)) + val segmentResult = segment(tensorImage) + val inferenceTime = SystemClock.uptimeMillis() - startTime + if (isActive) { + _segmentation.value = SegmentationResult(segmentResult, inferenceTime) + } + } + } catch (e: Exception) { + Log.i(TAG, "Image segmentation error occurred: ${e.message}") + _error.emit(e) + } + } + + + private fun segment(tensorImage: TensorImage): Segmentation { + val (_, h, w, c) = interpreter!!.getOutputTensor(0).shape() + val outputBuffer = FloatBuffer.allocate(h * w * c) + + outputBuffer.rewind() + interpreter?.run(tensorImage.tensorBuffer.buffer, outputBuffer) + + outputBuffer.rewind() + val inferenceData = + InferenceData(width = w, height = h, channels = c, buffer = outputBuffer) + val mask = processImage(inferenceData) + + val imageProperties = + ImageProperties + .builder() + .setWidth(inferenceData.width) + .setHeight(inferenceData.height) + .setColorSpaceType(ColorSpaceType.GRAYSCALE) + .build() + val maskImage = TensorImage() + maskImage.load(mask, imageProperties) + return Segmentation( + listOf(maskImage), coloredLabels + ) + } + + private fun processImage(inferenceData: InferenceData): ByteBuffer { + val mask = ByteBuffer.allocateDirect(inferenceData.width * inferenceData.height) + for (i in 0 until inferenceData.height) { + for (j in 0 until inferenceData.width) { + val offset = inferenceData.channels * (i * inferenceData.width + j) + + var maxIndex = 0 + var maxValue = inferenceData.buffer.get(offset) + + for (index in 1 until inferenceData.channels) { + if (inferenceData.buffer.get(offset + index) > maxValue) { + maxValue = inferenceData.buffer.get(offset + index) + maxIndex = index + } + } + + mask.put(i * inferenceData.width + j, maxIndex.toByte()) + } + } + + return mask + } + private fun coloredLabels(): List { + val labels = listOf( + "background", + "aeroplane", + "bicycle", + "bird", + "boat", + "bottle", + "bus", + "car", + "cat", + "chair", + "cow", + "dining table", + "dog", + "horse", + "motorbike", + "person", + "potted plant", + "sheep", + "sofa", + "train", + "tv", + "------" + ) + val colors = MutableList(labels.size) { + ColoredLabel( + labels[0], "", Color.BLACK + ) + } + + val random = Random() + val goldenRatioConjugate = 0.618033988749895 + var hue = random.nextDouble() + + // Skip the first label as it's already assigned black + for (idx in 1 until labels.size) { + hue += goldenRatioConjugate + hue %= 1.0 + // Adjust saturation & lightness as needed + val color = Color.HSVToColor(floatArrayOf(hue.toFloat() * 360, 0.7f, 0.8f)) + colors[idx] = ColoredLabel(labels[idx], "", color) + } + + return colors + } + + data class Segmentation( + val masks: List, + val coloredLabels: List, + ) + + data class ColoredLabel(val label: String, val displayName: String, val argb: Int) + + data class SegmentationResult( + val segmentation: Segmentation, + val inferenceTime: Long + ) + + data class InferenceData( + val width: Int, + val height: Int, + val channels: Int, + val buffer: FloatBuffer, + ) +} \ No newline at end of file diff --git a/app/src/main/java/org/mydomain/myscan/MainActivity.kt b/app/src/main/java/org/mydomain/myscan/MainActivity.kt index 71fe683..58d704b 100644 --- a/app/src/main/java/org/mydomain/myscan/MainActivity.kt +++ b/app/src/main/java/org/mydomain/myscan/MainActivity.kt @@ -5,37 +5,41 @@ import android.util.Log import androidx.activity.ComponentActivity import androidx.activity.compose.setContent import androidx.activity.enableEdgeToEdge +import androidx.activity.viewModels +import androidx.compose.foundation.background import androidx.compose.foundation.layout.Box import androidx.compose.foundation.layout.Column +import androidx.compose.foundation.layout.fillMaxWidth import androidx.compose.foundation.layout.padding import androidx.compose.material3.Scaffold import androidx.compose.material3.Text import androidx.compose.runtime.Composable +import androidx.compose.runtime.getValue import androidx.compose.ui.Modifier +import androidx.compose.ui.graphics.Color import androidx.compose.ui.tooling.preview.Preview +import androidx.compose.ui.unit.dp +import androidx.compose.ui.unit.sp +import androidx.lifecycle.compose.collectAsStateWithLifecycle import org.mydomain.myscan.ui.theme.MyScanTheme import org.mydomain.myscan.view.CameraScreen -import java.util.Date class MainActivity : ComponentActivity() { - companion object { - private const val TAG = "MyScan" - } - override fun onCreate(savedInstanceState: Bundle?) { super.onCreate(savedInstanceState) + val viewModel: MainViewModel by viewModels { MainViewModel.getFactory(this) } enableEdgeToEdge() setContent { + val uiState by viewModel.uiState.collectAsStateWithLifecycle() + Log.d("MyScan", "!!"+uiState.toString()) MyScanTheme { - Scaffold(/*modifier = Modifier.fillMaxSize()*/) { innerPadding -> + Scaffold { innerPadding -> Column { Greeting(modifier = Modifier.padding(innerPadding)) - Box(/*modifier = Modifier.width(300.dp)*/) { - CameraScreen(onImageAnalyzed = { image -> - Log.d(TAG, Date().toString()) - image.close() - } ) + MyMessageBox(uiState.detectionMessage, uiState.inferenceTime) + Box { + CameraScreen(onImageAnalyzed = { image -> viewModel.segment(image) } ) } } } @@ -44,6 +48,20 @@ class MainActivity : ComponentActivity() { } } +@Composable +fun MyMessageBox(msg: String?, inferenceTime: Long) { + Log.d("MyScan", "MyMessageBox recompose: $msg") + Text( + text = (msg ?: "") + " inferred in " + inferenceTime + "ms", + modifier = Modifier + .padding(16.dp) + .background(Color.Yellow) + .fillMaxWidth(), + color = Color.Black, + fontSize = 20.sp + ) +} + @Composable fun Greeting(modifier: Modifier = Modifier) { Text( @@ -54,8 +72,8 @@ fun Greeting(modifier: Modifier = Modifier) { @Preview(showBackground = true) @Composable -fun GreetingPreview() { +fun MyMessageBoxPreview() { MyScanTheme { - Greeting() + MyMessageBox("Found 2 objects!", 42) } } \ No newline at end of file diff --git a/app/src/main/java/org/mydomain/myscan/MainViewModel.kt b/app/src/main/java/org/mydomain/myscan/MainViewModel.kt new file mode 100644 index 0000000..cec7ac7 --- /dev/null +++ b/app/src/main/java/org/mydomain/myscan/MainViewModel.kt @@ -0,0 +1,56 @@ +package org.mydomain.myscan + +import android.content.Context +import android.util.Log +import androidx.camera.core.ImageProxy +import androidx.lifecycle.ViewModel +import androidx.lifecycle.ViewModelProvider +import androidx.lifecycle.viewModelScope +import androidx.lifecycle.viewmodel.CreationExtras +import kotlinx.coroutines.flow.MutableStateFlow +import kotlinx.coroutines.flow.StateFlow +import kotlinx.coroutines.flow.asStateFlow +import kotlinx.coroutines.flow.filterNotNull +import kotlinx.coroutines.flow.map +import kotlinx.coroutines.flow.update +import kotlinx.coroutines.launch + +class MainViewModel(private val imageSegmentationService: ImageSegmentationService): ViewModel() { + + companion object { + fun getFactory(context: Context) = object : ViewModelProvider.Factory { + override fun create(modelClass: Class, extras: CreationExtras): T { + return MainViewModel(ImageSegmentationService(context)) as T + } + } + } + + private var _uiState = MutableStateFlow(UiState("just started")) + val uiState: StateFlow = _uiState.asStateFlow() + + init { + viewModelScope.launch { + imageSegmentationService.initialize() + imageSegmentationService.segmentation + .filterNotNull() + .map { UiState("Found ${it.segmentation.masks.size} objects!", it.inferenceTime) } + .collect { + Log.d("MyScan", "New UIstate ${it}") + _uiState.value = it + } + } + } + + + fun segment(imageProxy: ImageProxy) { + Log.d("MyScan", "MainViewModel.Calling segment") + viewModelScope.launch { + imageSegmentationService.runSegmentation( + imageProxy.toBitmap(), + imageProxy.imageInfo.rotationDegrees, + ) + imageProxy.close() + } + } + +} \ No newline at end of file diff --git a/app/src/main/java/org/mydomain/myscan/UiState.kt b/app/src/main/java/org/mydomain/myscan/UiState.kt new file mode 100644 index 0000000..817fd54 --- /dev/null +++ b/app/src/main/java/org/mydomain/myscan/UiState.kt @@ -0,0 +1,10 @@ +package org.mydomain.myscan + +import androidx.compose.runtime.Immutable + +@Immutable +data class UiState( + val detectionMessage: String? = null, + val inferenceTime: Long = 0L, + val errorMessage: String? = null, +) diff --git a/gradle/libs.versions.toml b/gradle/libs.versions.toml index 87e83da..94574a7 100644 --- a/gradle/libs.versions.toml +++ b/gradle/libs.versions.toml @@ -9,6 +9,7 @@ lifecycleRuntimeKtx = "2.9.0" activityCompose = "1.10.1" composeBom = "2025.05.00" camerax = "1.4.2" +litert = "1.2.0" [libraries] androidx-core-ktx = { group = "androidx.core", name = "core-ktx", version.ref = "coreKtx" } @@ -16,6 +17,8 @@ junit = { group = "junit", name = "junit", version.ref = "junit" } androidx-junit = { group = "androidx.test.ext", name = "junit", version.ref = "junitVersion" } androidx-espresso-core = { group = "androidx.test.espresso", name = "espresso-core", version.ref = "espressoCore" } androidx-lifecycle-runtime-ktx = { group = "androidx.lifecycle", name = "lifecycle-runtime-ktx", version.ref = "lifecycleRuntimeKtx" } +androidx-lifecycle-viewmodel-compose = { group = "androidx.lifecycle", name = "lifecycle-viewmodel-compose", version.ref = "lifecycleRuntimeKtx" } +androidx-lifecycle-runtime-compose = { group = "androidx.lifecycle", name = "lifecycle-runtime-compose", version.ref = "lifecycleRuntimeKtx" } androidx-activity-compose = { group = "androidx.activity", name = "activity-compose", version.ref = "activityCompose" } androidx-compose-bom = { group = "androidx.compose", name = "compose-bom", version.ref = "composeBom" } androidx-ui = { group = "androidx.compose.ui", name = "ui" } @@ -30,6 +33,9 @@ androidx-camera-core = { group = "androidx.camera", name = "camera-core", versio androidx-camera-camera2 = { group = "androidx.camera", name = "camera-camera2", version.ref = "camerax" } androidx-camera-lifecycle = { group = "androidx.camera", name = "camera-lifecycle", version.ref = "camerax" } androidx-camera-view = { group = "androidx.camera", name = "camera-view", version.ref = "camerax" } +litert = { group = "com.google.ai.edge.litert", name = "litert", version.ref = "litert" } +litert-support = { group = "com.google.ai.edge.litert", name = "litert-support", version.ref = "litert" } +litert-metadata = { group = "com.google.ai.edge.litert", name = "litert-metadata", version.ref = "litert" } [plugins] android-application = { id = "com.android.application", version.ref = "agp" }