Use LiteRT to analyze an image: display a trivial message as a result

This commit is contained in:
Pierre-Yves Nicolas
2025-05-25 22:26:22 +02:00
parent d2b32b7527
commit 6f68bf05d6
8 changed files with 319 additions and 13 deletions

View File

@@ -0,0 +1,207 @@
package org.mydomain.myscan
import android.content.Context
import android.graphics.Bitmap
import android.graphics.Color
import android.os.SystemClock
import android.util.Log
import kotlinx.coroutines.CancellationException
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.channels.BufferOverflow
import kotlinx.coroutines.flow.MutableSharedFlow
import kotlinx.coroutines.flow.MutableStateFlow
import kotlinx.coroutines.flow.SharedFlow
import kotlinx.coroutines.flow.StateFlow
import kotlinx.coroutines.flow.asStateFlow
import kotlinx.coroutines.isActive
import kotlinx.coroutines.withContext
import org.tensorflow.lite.Interpreter
import org.tensorflow.lite.support.common.FileUtil
import org.tensorflow.lite.support.common.ops.NormalizeOp
import org.tensorflow.lite.support.image.ColorSpaceType
import org.tensorflow.lite.support.image.ImageProcessor
import org.tensorflow.lite.support.image.ImageProperties
import org.tensorflow.lite.support.image.TensorImage
import org.tensorflow.lite.support.image.ops.ResizeOp
import org.tensorflow.lite.support.image.ops.Rot90Op
import java.nio.ByteBuffer
import java.nio.FloatBuffer
import java.util.Random
// TODO Review and remove unneeded code
/**
 * Runs a LiteRT segmentation model on bitmaps and publishes the outcome as flows.
 *
 * Call [initialize] once to load the model, then [runSegmentation] for each frame. Successful
 * results are published on [segmentation]; failures from either call are emitted on [error].
 */
class ImageSegmentationService(private val context: Context) {

    private val _segmentation = MutableStateFlow<SegmentationResult?>(null)

    /** Most recent segmentation result, or `null` while no frame has been processed. */
    val segmentation: StateFlow<SegmentationResult?> = _segmentation.asStateFlow()

    private val _error = MutableSharedFlow<Throwable?>()

    /** Exceptions raised while loading the model or while segmenting a frame. */
    val error: SharedFlow<Throwable?>
        get() = _error

    // Assigned by initialize() and read from Dispatchers.IO in runSegmentation().
    @Volatile
    private var interpreter: Interpreter? = null

    private val coloredLabels: List<ColoredLabel> = coloredLabels()

    /**
     * Loads the bundled model and creates the interpreter.
     *
     * The blocking load runs on [Dispatchers.IO], so this function can be called from any
     * dispatcher. On failure the exception is emitted on [error] and the service stays
     * uninitialised, which turns [runSegmentation] into a no-op.
     */
    suspend fun initialize() {
        interpreter = try {
            withContext(Dispatchers.IO) {
                val litertBuffer = FileUtil.loadMappedFile(context, MODEL_FILE)
                Log.i(TAG, "Loaded LiteRT model")
                val options = Interpreter.Options().apply {
                    numThreads = NUM_THREADS
                }
                Interpreter(litertBuffer, options)
            }
        } catch (e: CancellationException) {
            // Cancellation is not a failure: let it propagate to the caller's scope.
            throw e
        } catch (e: Exception) {
            Log.e(TAG, "Failed to load LiteRT model: ${e.message}", e)
            _error.emit(e)
            null
        }
    }

    /**
     * Segments [bitmap] and publishes the result on [segmentation].
     *
     * Does nothing when the model has not been loaded. Any exception other than cancellation is
     * emitted on [error] instead of being thrown.
     *
     * @param bitmap frame to analyse.
     * @param rotationDegrees rotation reported for the frame, in degrees; it is divided by 90 and
     * negated to obtain the number of quarter turns passed to [Rot90Op].
     */
    suspend fun runSegmentation(bitmap: Bitmap, rotationDegrees: Int) {
        // Read the property once so the whole frame is processed by the same interpreter.
        val model = interpreter ?: return
        try {
            withContext(Dispatchers.IO) {
                val startTime = SystemClock.uptimeMillis()
                val quarterTurns = -rotationDegrees / 90
                // The image is fed to the model, so it must be resized to the INPUT tensor's
                // height and width.
                val (_, inputHeight, inputWidth, _) = model.getInputTensor(0).shape()
                val imageProcessor =
                    ImageProcessor
                        .Builder()
                        .add(ResizeOp(inputHeight, inputWidth, ResizeOp.ResizeMethod.BILINEAR))
                        .add(Rot90Op(quarterTurns))
                        // (pixel - 127.5) / 127.5
                        .add(NormalizeOp(NORMALIZE_MEAN, NORMALIZE_STDDEV))
                        .build()
                val tensorImage = imageProcessor.process(TensorImage.fromBitmap(bitmap))
                val segmentResult = segment(model, tensorImage)
                val inferenceTime = SystemClock.uptimeMillis() - startTime
                // Do not publish a result for a request that was cancelled in the meantime.
                if (isActive) {
                    _segmentation.value = SegmentationResult(segmentResult, inferenceTime)
                }
            }
        } catch (e: CancellationException) {
            // Cancellation is not a failure: let it propagate to the caller's scope.
            throw e
        } catch (e: Exception) {
            Log.e(TAG, "Image segmentation error occurred: ${e.message}", e)
            _error.emit(e)
        }
    }

    /**
     * Runs inference on [tensorImage] with [model] and wraps the per-pixel class mask in a
     * [Segmentation].
     */
    private fun segment(model: Interpreter, tensorImage: TensorImage): Segmentation {
        val (_, h, w, c) = model.getOutputTensor(0).shape()
        val outputBuffer = FloatBuffer.allocate(h * w * c)
        model.run(tensorImage.tensorBuffer.buffer, outputBuffer)
        // Reset the position before the scores are read back.
        outputBuffer.rewind()
        val inferenceData =
            InferenceData(width = w, height = h, channels = c, buffer = outputBuffer)
        val mask = processImage(inferenceData)
        val imageProperties =
            ImageProperties
                .builder()
                .setWidth(inferenceData.width)
                .setHeight(inferenceData.height)
                .setColorSpaceType(ColorSpaceType.GRAYSCALE)
                .build()
        val maskImage = TensorImage()
        maskImage.load(mask, imageProperties)
        return Segmentation(listOf(maskImage), coloredLabels)
    }

    /**
     * Builds a one-byte-per-pixel mask holding, for each pixel, the index of the channel with the
     * highest score. On ties the lowest channel index wins.
     */
    private fun processImage(inferenceData: InferenceData): ByteBuffer {
        val (width, height, channels, scores) = inferenceData
        val pixelCount = width * height
        val mask = ByteBuffer.allocateDirect(pixelCount)
        for (pixel in 0 until pixelCount) {
            val offset = channels * pixel
            var bestChannel = 0
            var bestScore = scores.get(offset)
            for (channel in 1 until channels) {
                val score = scores.get(offset + channel)
                if (score > bestScore) {
                    bestScore = score
                    bestChannel = channel
                }
            }
            mask.put(pixel, bestChannel.toByte())
        }
        return mask
    }

    /**
     * Pairs every entry of [LABELS] with a display colour: black for the first entry, then hues
     * stepped by the golden-ratio conjugate from a random starting hue, so the colours differ
     * between instances.
     */
    private fun coloredLabels(): List<ColoredLabel> {
        var hue = Random().nextDouble()
        return LABELS.mapIndexed { index, label ->
            if (index == 0) {
                ColoredLabel(label, "", Color.BLACK)
            } else {
                hue = (hue + GOLDEN_RATIO_CONJUGATE) % 1.0
                val color = Color.HSVToColor(floatArrayOf(hue.toFloat() * 360, 0.7f, 0.8f))
                ColoredLabel(label, "", color)
            }
        }
    }

    /** Mask images produced for one frame together with the label/colour table. */
    data class Segmentation(
        val masks: List<TensorImage>,
        val coloredLabels: List<ColoredLabel>,
    )

    /** A class label with its display name and ARGB colour. */
    data class ColoredLabel(val label: String, val displayName: String, val argb: Int)

    /** A [Segmentation] and the time spent producing it, in milliseconds. */
    data class SegmentationResult(
        val segmentation: Segmentation,
        val inferenceTime: Long
    )

    /** Raw model output: [buffer] holds [channels] scores for each of the [width] x [height] pixels. */
    data class InferenceData(
        val width: Int,
        val height: Int,
        val channels: Int,
        val buffer: FloatBuffer,
    )

    companion object {
        private const val TAG = "ImageSegmentation"
        private const val MODEL_FILE = "deeplab_v3.tflite"
        private const val NUM_THREADS = 2
        private const val NORMALIZE_MEAN = 127.5f
        private const val NORMALIZE_STDDEV = 127.5f
        private const val GOLDEN_RATIO_CONJUGATE = 0.618033988749895

        // NOTE(review): presumably listed in the model's output-channel order - confirm against
        // the model metadata.
        private val LABELS = listOf(
            "background",
            "aeroplane",
            "bicycle",
            "bird",
            "boat",
            "bottle",
            "bus",
            "car",
            "cat",
            "chair",
            "cow",
            "dining table",
            "dog",
            "horse",
            "motorbike",
            "person",
            "potted plant",
            "sheep",
            "sofa",
            "train",
            "tv",
            "------"
        )
    }
}

View File

@@ -5,37 +5,41 @@ import android.util.Log
import androidx.activity.ComponentActivity
import androidx.activity.compose.setContent
import androidx.activity.enableEdgeToEdge
import androidx.activity.viewModels
import androidx.compose.foundation.background
import androidx.compose.foundation.layout.Box
import androidx.compose.foundation.layout.Column
import androidx.compose.foundation.layout.fillMaxWidth
import androidx.compose.foundation.layout.padding
import androidx.compose.material3.Scaffold
import androidx.compose.material3.Text
import androidx.compose.runtime.Composable
import androidx.compose.runtime.getValue
import androidx.compose.ui.Modifier
import androidx.compose.ui.graphics.Color
import androidx.compose.ui.tooling.preview.Preview
import androidx.compose.ui.unit.dp
import androidx.compose.ui.unit.sp
import androidx.lifecycle.compose.collectAsStateWithLifecycle
import org.mydomain.myscan.ui.theme.MyScanTheme
import org.mydomain.myscan.view.CameraScreen
import java.util.Date
class MainActivity : ComponentActivity() {
companion object {
private const val TAG = "MyScan"
}
override fun onCreate(savedInstanceState: Bundle?) {
super.onCreate(savedInstanceState)
val viewModel: MainViewModel by viewModels { MainViewModel.getFactory(this) }
enableEdgeToEdge()
setContent {
val uiState by viewModel.uiState.collectAsStateWithLifecycle()
Log.d("MyScan", "!!"+uiState.toString())
MyScanTheme {
Scaffold(/*modifier = Modifier.fillMaxSize()*/) { innerPadding ->
Scaffold { innerPadding ->
Column {
Greeting(modifier = Modifier.padding(innerPadding))
Box(/*modifier = Modifier.width(300.dp)*/) {
CameraScreen(onImageAnalyzed = { image ->
Log.d(TAG, Date().toString())
image.close()
} )
MyMessageBox(uiState.detectionMessage, uiState.inferenceTime)
Box {
CameraScreen(onImageAnalyzed = { image -> viewModel.segment(image) } )
}
}
}
@@ -44,6 +48,20 @@ class MainActivity : ComponentActivity() {
}
}
/**
 * Full-width yellow banner showing [msg] (or nothing when it is `null`) followed by the
 * inference duration.
 *
 * @param msg message to display; `null` is rendered as an empty string.
 * @param inferenceTime duration appended to the message with an "ms" suffix.
 */
@Composable
fun MyMessageBox(msg: String?, inferenceTime: Long) {
    Log.d("MyScan", "MyMessageBox recompose: $msg")
    val caption = buildString {
        append(msg ?: "")
        append(" inferred in ")
        append(inferenceTime)
        append("ms")
    }
    val bannerModifier = Modifier
        .padding(16.dp)
        .background(Color.Yellow)
        .fillMaxWidth()
    Text(
        text = caption,
        modifier = bannerModifier,
        color = Color.Black,
        fontSize = 20.sp,
    )
}
@Composable
fun Greeting(modifier: Modifier = Modifier) {
Text(
@@ -54,8 +72,8 @@ fun Greeting(modifier: Modifier = Modifier) {
@Preview(showBackground = true)
@Composable
fun GreetingPreview() {
fun MyMessageBoxPreview() {
MyScanTheme {
Greeting()
MyMessageBox("Found 2 objects!", 42)
}
}

View File

@@ -0,0 +1,56 @@
package org.mydomain.myscan
import android.content.Context
import android.util.Log
import androidx.camera.core.ImageProxy
import androidx.lifecycle.ViewModel
import androidx.lifecycle.ViewModelProvider
import androidx.lifecycle.viewModelScope
import androidx.lifecycle.viewmodel.CreationExtras
import kotlinx.coroutines.flow.MutableStateFlow
import kotlinx.coroutines.flow.StateFlow
import kotlinx.coroutines.flow.asStateFlow
import kotlinx.coroutines.flow.filterNotNull
import kotlinx.coroutines.flow.map
import kotlinx.coroutines.flow.update
import kotlinx.coroutines.launch
/**
 * Feeds camera frames to [ImageSegmentationService] and exposes the outcome as a [UiState] flow.
 */
class MainViewModel(private val imageSegmentationService: ImageSegmentationService) : ViewModel() {

    private val _uiState = MutableStateFlow(UiState("just started"))

    /** Current screen state; replaced each time a new segmentation result arrives. */
    val uiState: StateFlow<UiState> = _uiState.asStateFlow()

    init {
        // Launched before the coroutine that calls initialize(), so that errors reported by
        // the service are collected instead of being silently dropped.
        viewModelScope.launch {
            imageSegmentationService.error
                .filterNotNull()
                .collect { throwable ->
                    Log.e(TAG, "Segmentation service error", throwable)
                    _uiState.update { it.copy(errorMessage = throwable.message) }
                }
        }
        viewModelScope.launch {
            imageSegmentationService.initialize()
            imageSegmentationService.segmentation
                .filterNotNull()
                .map { UiState("Found ${it.segmentation.masks.size} objects!", it.inferenceTime) }
                .collect {
                    Log.d(TAG, "New UIstate ${it}")
                    _uiState.value = it
                }
        }
    }

    /**
     * Segments the frame held by [imageProxy] and closes the proxy afterwards.
     *
     * @param imageProxy frame to analyse; ownership is taken over, callers must not close it.
     */
    fun segment(imageProxy: ImageProxy) {
        Log.d(TAG, "MainViewModel.Calling segment")
        viewModelScope.launch {
            // Close the frame even when segmentation throws or the coroutine is cancelled;
            // an ImageProxy that is never closed keeps its underlying image buffer.
            try {
                imageSegmentationService.runSegmentation(
                    imageProxy.toBitmap(),
                    imageProxy.imageInfo.rotationDegrees,
                )
            } finally {
                imageProxy.close()
            }
        }
    }

    companion object {
        private const val TAG = "MyScan"

        /**
         * Creates a factory that builds a [MainViewModel] backed by a new
         * [ImageSegmentationService].
         *
         * @param context any context; only its application context is retained, so passing an
         * Activity does not keep that Activity alive for the lifetime of the ViewModel.
         */
        fun getFactory(context: Context): ViewModelProvider.Factory =
            object : ViewModelProvider.Factory {
                private val appContext = context.applicationContext

                // The only ViewModel this factory is used for is MainViewModel.
                @Suppress("UNCHECKED_CAST")
                override fun <T : ViewModel> create(modelClass: Class<T>, extras: CreationExtras): T {
                    return MainViewModel(ImageSegmentationService(appContext)) as T
                }
            }
    }
}

View File

@@ -0,0 +1,10 @@
package org.mydomain.myscan
import androidx.compose.runtime.Immutable
/**
 * Snapshot of what the main screen displays.
 *
 * @property detectionMessage message describing the latest analysis result, or `null` if none.
 * @property inferenceTime duration of the latest inference; shown with an "ms" suffix by the UI.
 * @property errorMessage description of the latest error, or `null` if none.
 */
@Immutable
data class UiState(
    val detectionMessage: String? = null,
    val inferenceTime: Long = 0L,
    val errorMessage: String? = null,
)