Display overlay for the segmentation

This commit is contained in:
Pierre-Yves Nicolas
2025-05-27 09:38:23 +02:00
parent 95bd3bd823
commit 3457a85044
5 changed files with 70 additions and 11 deletions

View File

@@ -2,11 +2,11 @@ package org.mydomain.myscan
import android.content.Context import android.content.Context
import android.graphics.Bitmap import android.graphics.Bitmap
import android.graphics.Color import android.graphics.Bitmap.createBitmap
import android.graphics.Color.argb
import android.os.SystemClock import android.os.SystemClock
import android.util.Log import android.util.Log
import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.channels.BufferOverflow
import kotlinx.coroutines.flow.MutableSharedFlow import kotlinx.coroutines.flow.MutableSharedFlow
import kotlinx.coroutines.flow.MutableStateFlow import kotlinx.coroutines.flow.MutableStateFlow
import kotlinx.coroutines.flow.SharedFlow import kotlinx.coroutines.flow.SharedFlow
@@ -25,7 +25,6 @@ import org.tensorflow.lite.support.image.ops.ResizeOp
import org.tensorflow.lite.support.image.ops.Rot90Op import org.tensorflow.lite.support.image.ops.Rot90Op
import java.nio.ByteBuffer import java.nio.ByteBuffer
import java.nio.FloatBuffer import java.nio.FloatBuffer
import java.util.Random
// TODO Review and remove unneeded code // TODO Review and remove unneeded code
class ImageSegmentationService(private val context: Context) { class ImageSegmentationService(private val context: Context) {
@@ -110,7 +109,7 @@ class ImageSegmentationService(private val context: Context) {
.build() .build()
val maskImage = TensorImage() val maskImage = TensorImage()
maskImage.load(mask, imageProperties) maskImage.load(mask, imageProperties)
return Segmentation(listOf(maskImage)) return Segmentation(maskImage)
} }
private fun processImage(inferenceData: InferenceData): ByteBuffer { private fun processImage(inferenceData: InferenceData): ByteBuffer {
@@ -136,9 +135,23 @@ class ImageSegmentationService(private val context: Context) {
return mask return mask
} }
data class Segmentation( data class Segmentation(val mask: TensorImage) {
val masks: List<TensorImage> fun toBitmap(): Bitmap {
) val width = mask.width
val height = mask.height
val pixels = IntArray(width * height)
val green = argb(128, 0, 255, 0)
for (i in 0 until height) {
for (j in 0 until width) {
val index = i * width + j
val classId = mask.buffer[index].toInt() and 0xFF // Unsigned byte
pixels[index] = if (classId == 0) 0 else green
}
}
return createBitmap(pixels, width, height, Bitmap.Config.ARGB_8888)
}
}
data class SegmentationResult( data class SegmentationResult(
val segmentation: Segmentation, val segmentation: Segmentation,

View File

@@ -5,6 +5,7 @@ import androidx.activity.ComponentActivity
import androidx.activity.compose.setContent import androidx.activity.compose.setContent
import androidx.activity.enableEdgeToEdge import androidx.activity.enableEdgeToEdge
import androidx.activity.viewModels import androidx.activity.viewModels
import androidx.compose.foundation.background
import androidx.compose.foundation.layout.Box import androidx.compose.foundation.layout.Box
import androidx.compose.foundation.layout.Column import androidx.compose.foundation.layout.Column
import androidx.compose.foundation.layout.fillMaxWidth import androidx.compose.foundation.layout.fillMaxWidth
@@ -35,7 +36,7 @@ class MainActivity : ComponentActivity() {
Greeting(modifier = Modifier.padding(innerPadding)) Greeting(modifier = Modifier.padding(innerPadding))
MyMessageBox(uiState.detectionMessage, uiState.inferenceTime) MyMessageBox(uiState.detectionMessage, uiState.inferenceTime)
Box { Box {
CameraScreen(onImageAnalyzed = { image -> viewModel.segment(image) } ) CameraScreen(uiState, onImageAnalyzed = { image -> viewModel.segment(image) } )
} }
} }
} }
@@ -50,6 +51,7 @@ fun MyMessageBox(msg: String?, inferenceTime: Long) {
text = (msg ?: "") + " / inferred in " + inferenceTime + "ms", text = (msg ?: "") + " / inferred in " + inferenceTime + "ms",
modifier = Modifier modifier = Modifier
.padding(16.dp) .padding(16.dp)
.background(Color.Gray)
.fillMaxWidth(), .fillMaxWidth(),
color = Color.Black, color = Color.Black,
) )

View File

@@ -32,7 +32,12 @@ class MainViewModel(private val imageSegmentationService: ImageSegmentationServi
imageSegmentationService.initialize() imageSegmentationService.initialize()
imageSegmentationService.segmentation imageSegmentationService.segmentation
.filterNotNull() .filterNotNull()
.map { UiState("Found ${numberOfObjectsDetected(it.segmentation)} objects!", it.inferenceTime) } .map {
UiState(
"Found ${numberOfObjectsDetected(it.segmentation)} objects!",
it.inferenceTime,
it.segmentation.toBitmap())
}
.collect { .collect {
Log.d("MyScan", "New UIstate ${it}") Log.d("MyScan", "New UIstate ${it}")
_uiState.value = it _uiState.value = it
@@ -41,7 +46,7 @@ class MainViewModel(private val imageSegmentationService: ImageSegmentationServi
} }
fun numberOfObjectsDetected(segmentation: ImageSegmentationService.Segmentation) : Int { fun numberOfObjectsDetected(segmentation: ImageSegmentationService.Segmentation) : Int {
val tensor = segmentation.masks[0]; val tensor = segmentation.mask;
val buffer = tensor.buffer val buffer = tensor.buffer
val uniqueValues = HashSet<Int>() val uniqueValues = HashSet<Int>()
for (i in 0..tensor.width * tensor.height - 1) { for (i in 0..tensor.width * tensor.height - 1) {

View File

@@ -1,10 +1,12 @@
package org.mydomain.myscan package org.mydomain.myscan
import android.graphics.Bitmap
import androidx.compose.runtime.Immutable import androidx.compose.runtime.Immutable
@Immutable @Immutable
data class UiState( data class UiState(
val detectionMessage: String? = null, val detectionMessage: String? = null,
val inferenceTime: Long = 0L, val inferenceTime: Long = 0L,
val overlayBitmap: Bitmap? = null,
val errorMessage: String? = null, val errorMessage: String? = null,
) )

View File

@@ -1,6 +1,7 @@
package org.mydomain.myscan.view package org.mydomain.myscan.view
import android.content.pm.PackageManager.PERMISSION_GRANTED import android.content.pm.PackageManager.PERMISSION_GRANTED
import android.graphics.Bitmap
import android.view.ViewGroup.LayoutParams.MATCH_PARENT import android.view.ViewGroup.LayoutParams.MATCH_PARENT
import android.widget.LinearLayout import android.widget.LinearLayout
import android.widget.Toast import android.widget.Toast
@@ -13,6 +14,11 @@ import androidx.camera.core.ImageProxy
import androidx.camera.core.Preview import androidx.camera.core.Preview
import androidx.camera.lifecycle.ProcessCameraProvider import androidx.camera.lifecycle.ProcessCameraProvider
import androidx.camera.view.PreviewView import androidx.camera.view.PreviewView
import androidx.compose.foundation.Canvas
import androidx.compose.foundation.layout.Box
import androidx.compose.foundation.layout.fillMaxSize
import androidx.compose.foundation.layout.height
import androidx.compose.foundation.layout.width
import androidx.compose.runtime.Composable import androidx.compose.runtime.Composable
import androidx.compose.runtime.DisposableEffect import androidx.compose.runtime.DisposableEffect
import androidx.compose.runtime.LaunchedEffect import androidx.compose.runtime.LaunchedEffect
@@ -20,17 +26,22 @@ import androidx.compose.runtime.getValue
import androidx.compose.runtime.mutableStateOf import androidx.compose.runtime.mutableStateOf
import androidx.compose.runtime.remember import androidx.compose.runtime.remember
import androidx.compose.ui.Modifier import androidx.compose.ui.Modifier
import androidx.compose.ui.graphics.asImageBitmap
import androidx.compose.ui.platform.LocalConfiguration
import androidx.compose.ui.platform.LocalContext import androidx.compose.ui.platform.LocalContext
import androidx.compose.ui.unit.dp
import androidx.compose.ui.viewinterop.AndroidView import androidx.compose.ui.viewinterop.AndroidView
import androidx.core.content.ContextCompat import androidx.core.content.ContextCompat
import androidx.lifecycle.LifecycleOwner import androidx.lifecycle.LifecycleOwner
import androidx.lifecycle.compose.LocalLifecycleOwner import androidx.lifecycle.compose.LocalLifecycleOwner
import com.google.common.util.concurrent.ListenableFuture import com.google.common.util.concurrent.ListenableFuture
import org.mydomain.myscan.UiState
import java.util.concurrent.ExecutorService import java.util.concurrent.ExecutorService
import java.util.concurrent.Executors import java.util.concurrent.Executors
@Composable @Composable
fun CameraScreen( fun CameraScreen(
uiState: UiState,
onImageAnalyzed: (ImageProxy) -> Unit, onImageAnalyzed: (ImageProxy) -> Unit,
) { ) {
// TODO Check the errors in the logs before the user gives the required authorization // TODO Check the errors in the logs before the user gives the required authorization
@@ -50,7 +61,21 @@ fun CameraScreen(
} }
} }
val width = LocalConfiguration.current.screenWidthDp
val height = width / 3 * 4
Box(
modifier = Modifier
.width(width.dp)
.height(height.dp)
) {
CameraPreview(onImageAnalyzed = onImageAnalyzed) CameraPreview(onImageAnalyzed = onImageAnalyzed)
if (uiState.overlayBitmap != null) {
SegmentationOverlay(
modifier = Modifier.fillMaxSize(),
overlay = uiState.overlayBitmap
)
}
}
} }
@Composable @Composable
@@ -112,3 +137,15 @@ fun bindCameraUseCases(
cameraProvider.bindToLifecycle(lifecycleOwner, cameraSelector, imageAnalysis, preview) cameraProvider.bindToLifecycle(lifecycleOwner, cameraSelector, imageAnalysis, preview)
} }
@Composable
fun SegmentationOverlay(modifier: Modifier = Modifier, overlay: Bitmap) {
Canvas(
modifier = modifier
) {
val imageWidth: Float = size.width
val imageHeight: Float = size.height
val scaleBitmap =
Bitmap.createScaledBitmap(overlay, imageWidth.toInt(), imageHeight.toInt(), true)
drawImage(scaleBitmap.asImageBitmap())
}
}