From 0c5a219783ca4b520d8a164f7ca392947fe88c91 Mon Sep 17 00:00:00 2001 From: Pierre-Yves Nicolas <6371790+pynicolas@users.noreply.github.com> Date: Thu, 19 Jun 2025 15:00:20 +0200 Subject: [PATCH] New segmentation model --- .../org/mydomain/myscan/ImageSegmentation.kt | 154 +++++------------- .../java/org/mydomain/myscan/MainViewModel.kt | 7 + .../java/org/mydomain/myscan/view/Camera.kt | 2 + 3 files changed, 52 insertions(+), 111 deletions(-) diff --git a/app/src/main/java/org/mydomain/myscan/ImageSegmentation.kt b/app/src/main/java/org/mydomain/myscan/ImageSegmentation.kt index a0942a6..9116250 100644 --- a/app/src/main/java/org/mydomain/myscan/ImageSegmentation.kt +++ b/app/src/main/java/org/mydomain/myscan/ImageSegmentation.kt @@ -18,23 +18,24 @@ import android.content.Context import android.graphics.Bitmap import android.graphics.Bitmap.createBitmap import android.graphics.Color -import android.graphics.Matrix import android.os.SystemClock import android.util.Log -import androidx.core.graphics.scale import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.flow.MutableStateFlow import kotlinx.coroutines.flow.StateFlow import kotlinx.coroutines.flow.asStateFlow import kotlinx.coroutines.isActive import kotlinx.coroutines.withContext +import org.tensorflow.lite.DataType import org.tensorflow.lite.Interpreter import org.tensorflow.lite.support.common.FileUtil -import org.tensorflow.lite.support.image.ColorSpaceType -import org.tensorflow.lite.support.image.ImageProperties +import org.tensorflow.lite.support.common.ops.NormalizeOp +import org.tensorflow.lite.support.image.ImageProcessor import org.tensorflow.lite.support.image.TensorImage +import org.tensorflow.lite.support.image.ops.ResizeOp +import org.tensorflow.lite.support.image.ops.Rot90Op import java.nio.ByteBuffer -import java.nio.FloatBuffer +import java.nio.ByteOrder class ImageSegmentationService(private val context: Context) { @@ -49,7 +50,7 @@ class ImageSegmentationService(private val context: Context) { fun initialize() { interpreter = try { - val litertBuffer = FileUtil.loadMappedFile(context, "mydeeplabv3.tflite") + val litertBuffer = FileUtil.loadMappedFile(context, "timm_efficientnet_lite0_quantized.tflite") Log.i(TAG, "Loaded LiteRT model") val options = Interpreter.Options().apply { numThreads = 2 @@ -64,19 +65,21 @@ class ImageSegmentationService(private val context: Context) { private fun runSegmentation(interpreter: Interpreter, bitmap: Bitmap, rotationDegrees: Int): SegmentationResult { val startTime = SystemClock.uptimeMillis() - val (_, _, h, w) = interpreter.getInputTensor(0).shape() - // Preprocess manually into CHW float buffer - val inputBuffer = bitmapToCHWFloatBuffer(bitmap, width = w, height = h, rotationDegrees) - - val (_, cOut, hOut, wOut) = interpreter.getOutputTensor(0).shape() - val outputBuffer = FloatBuffer.allocate(cOut * hOut * wOut) - - // Run inference - outputBuffer.rewind() - interpreter.run(inputBuffer, outputBuffer) + val rotation = -rotationDegrees / 90 + val (_, h, w, _) = interpreter.getOutputTensor(0).shape() + val imageProcessor = + ImageProcessor + .Builder() + .add(ResizeOp(h, w, ResizeOp.ResizeMethod.BILINEAR)) + .add(Rot90Op(rotation)) + .add(NormalizeOp(127.5f, 127.5f)) // TODO check if it's correct + .build() + val tensorImage = TensorImage(DataType.FLOAT32) + tensorImage.load(bitmap) + val processedImage = imageProcessor.process(tensorImage) + val segmentResult = segment(interpreter, processedImage) val inferenceTime = SystemClock.uptimeMillis() - startTime - val segmentResult = processOutputBuffer(outputBuffer, wOut, hOut, cOut) return SegmentationResult(segmentResult, inferenceTime) } @@ -101,111 +104,40 @@ class ImageSegmentationService(private val context: Context) { } } - fun bitmapToCHWFloatBuffer(bitmap: Bitmap, width: Int, height: Int, rotationDegrees: Int): FloatBuffer { - val rotatedBitmap = if (rotationDegrees != 0) { - val matrix = Matrix().apply { postRotate(rotationDegrees.toFloat()) } - createBitmap(bitmap, 0, 0, bitmap.width, bitmap.height, matrix, true) - } else { - bitmap - } + private fun segment(interpreter: Interpreter, tensorImage: TensorImage): Segmentation { + val (_, h, w, _) = interpreter.getOutputTensor(0).shape() + val outputBuffer = ByteBuffer.allocateDirect(4 * h * w) + outputBuffer.order(ByteOrder.nativeOrder()) + outputBuffer.rewind() + interpreter.run(tensorImage.tensorBuffer.buffer, outputBuffer) + outputBuffer.rewind() + val mask = generateMaskFromOutputBuffer(outputBuffer, w, h) + return Segmentation(mask) + } - val resized = rotatedBitmap.scale(width, height) - val buffer = FloatBuffer.allocate(1 * 3 * height * width) - buffer.rewind() - - val mean = floatArrayOf(0.4611f, 0.4359f, 0.3905f) - val std = floatArrayOf(0.2193f, 0.2150f, 0.2109f) + private fun generateMaskFromOutputBuffer(outputBuffer: ByteBuffer, width: Int, height: Int): Bitmap { + outputBuffer.rewind() + val floatArray = FloatArray(width * height) + outputBuffer.asFloatBuffer()[floatArray] val pixels = IntArray(width * height) - resized.getPixels(pixels, 0, width, 0, 0, width, height) - - // Fill buffer in CHW order - for (c in 0..2) { - for (i in 0 until height) { - for (j in 0 until width) { - val pixel = pixels[i * width + j] - val value = when (c) { - 0 -> (pixel shr 16 and 0xFF) // R - 1 -> (pixel shr 8 and 0xFF) // G - 2 -> (pixel and 0xFF) // B - else -> 0 - } - val normalized = (value / 255f - mean[c]) / std[c] - buffer.put(normalized) - } - } + for (i in floatArray.indices) { + val value = floatArray[i].coerceIn(0f, 1f) + val gray = (value * 255).toInt() + pixels[i] = Color.rgb(gray, gray, gray) } - buffer.rewind() - return buffer + val bitmap = createBitmap(width, height, Bitmap.Config.ARGB_8888) + bitmap.setPixels(pixels, 0, width, 0, 0, width, height) + return bitmap } - private fun processOutputBuffer(outputBuffer: FloatBuffer, w: Int, h: Int, c: Int): Segmentation { - outputBuffer.rewind() - val inferenceData = - InferenceData(width = w, height = h, channels = c, buffer = outputBuffer) - val mask = generateMaskFromOutputBuffer(inferenceData) - - val imageProperties = - ImageProperties - .builder() - .setWidth(inferenceData.width) - .setHeight(inferenceData.height) - .setColorSpaceType(ColorSpaceType.GRAYSCALE) - .build() - val maskImage = TensorImage() - maskImage.load(mask, imageProperties) - return Segmentation(maskImage) - } - - private fun generateMaskFromOutputBuffer(inferenceData: InferenceData): ByteBuffer { - val width = inferenceData.width - val height = inferenceData.height - val mask = ByteBuffer.allocateDirect(width * height) - for (i in 0 until height) { - for (j in 0 until width) { - var maxIndex = 0 - var maxValue = inferenceData.buffer[i * width + j] - - for (c in 1 until inferenceData.channels) { - val value = inferenceData.buffer[c * height * width + i * width + j] - if (value > maxValue) { - maxValue = value - maxIndex = c - } - } - - mask.put(i * width + j, maxIndex.toByte()) - } - } - return mask - } - - data class Segmentation(val mask: TensorImage) { - fun toBinaryMask(): Bitmap { - val width = mask.width - val height = mask.height - val pixels = IntArray(width * height) - for (i in 0 until height) { - for (j in 0 until width) { - val index = i * width + j - val classId = mask.buffer[index].toInt() and 0xFF // Unsigned byte - pixels[index] = if (classId == 0) Color.BLACK else Color.WHITE - } - } - return createBitmap(pixels, width, height, Bitmap.Config.ARGB_8888) - } + data class Segmentation(val mask: Bitmap) { + fun toBinaryMask(): Bitmap = mask } data class SegmentationResult( val segmentation: Segmentation, val inferenceTime: Long ) - - data class InferenceData( - val width: Int, - val height: Int, - val channels: Int, - val buffer: FloatBuffer, - ) } diff --git a/app/src/main/java/org/mydomain/myscan/MainViewModel.kt b/app/src/main/java/org/mydomain/myscan/MainViewModel.kt index 42724f1..d0f7c2d 100644 --- a/app/src/main/java/org/mydomain/myscan/MainViewModel.kt +++ b/app/src/main/java/org/mydomain/myscan/MainViewModel.kt @@ -60,6 +60,8 @@ class MainViewModel( private var _pageToValidate = MutableStateFlow(null) val pageToValidate: StateFlow = _pageToValidate.asStateFlow() + var liveAnalysisEnabled = true + init { viewModelScope.launch { imageSegmentationService.initialize() @@ -80,6 +82,11 @@ class MainViewModel( } fun segment(imageProxy: ImageProxy) { + if (!liveAnalysisEnabled) { + imageProxy.close() + return + } + viewModelScope.launch { imageSegmentationService.runSegmentationAndEmit( imageProxy.toBitmap(), diff --git a/app/src/main/java/org/mydomain/myscan/view/Camera.kt b/app/src/main/java/org/mydomain/myscan/view/Camera.kt index e0be20a..2b2c4c1 100644 --- a/app/src/main/java/org/mydomain/myscan/view/Camera.kt +++ b/app/src/main/java/org/mydomain/myscan/view/Camera.kt @@ -131,6 +131,7 @@ fun CameraScreen( pageCount = viewModel.pageCount(), liveAnalysisState, onCapture = { + viewModel.liveAnalysisEnabled = false showPageDialog.value = true isProcessing.value = true captureController.takePicture( @@ -138,6 +139,7 @@ fun CameraScreen( if (imageProxy != null) { viewModel.processCapturedImageThen(imageProxy) { isProcessing.value = false + viewModel.liveAnalysisEnabled = true } } else { Log.e("MyScan", "Error during image capture")