From 73b0d47796e67d843772850fe6bb5f049a2edcb2 Mon Sep 17 00:00:00 2001 From: Pierre-Yves Nicolas <6371790+pynicolas@users.noreply.github.com> Date: Sun, 1 Jun 2025 06:58:51 +0200 Subject: [PATCH] ImageSegmentation: small refactorings --- ...ntationService.kt => ImageSegmentation.kt} | 35 +++++++------------ 1 file changed, 13 insertions(+), 22 deletions(-) rename app/src/main/java/org/mydomain/myscan/{ImageSegmentationService.kt => ImageSegmentation.kt} (85%) diff --git a/app/src/main/java/org/mydomain/myscan/ImageSegmentationService.kt b/app/src/main/java/org/mydomain/myscan/ImageSegmentation.kt similarity index 85% rename from app/src/main/java/org/mydomain/myscan/ImageSegmentationService.kt rename to app/src/main/java/org/mydomain/myscan/ImageSegmentation.kt index e96c52a..86f89b6 100644 --- a/app/src/main/java/org/mydomain/myscan/ImageSegmentationService.kt +++ b/app/src/main/java/org/mydomain/myscan/ImageSegmentation.kt @@ -9,9 +9,7 @@ import android.os.SystemClock import android.util.Log import androidx.core.graphics.scale import kotlinx.coroutines.Dispatchers -import kotlinx.coroutines.flow.MutableSharedFlow import kotlinx.coroutines.flow.MutableStateFlow -import kotlinx.coroutines.flow.SharedFlow import kotlinx.coroutines.flow.StateFlow import kotlinx.coroutines.flow.asStateFlow import kotlinx.coroutines.isActive @@ -24,7 +22,6 @@ import org.tensorflow.lite.support.image.TensorImage import java.nio.ByteBuffer import java.nio.FloatBuffer -// TODO Review and remove unneeded code class ImageSegmentationService(private val context: Context) { companion object { @@ -34,13 +31,9 @@ class ImageSegmentationService(private val context: Context) { private val _segmentation = MutableStateFlow(null) val segmentation: StateFlow = _segmentation.asStateFlow() - val error: SharedFlow - get() = _error - private val _error = MutableSharedFlow() - private var interpreter: Interpreter? = null - suspend fun initialize() { + fun initialize() { interpreter = try { val litertBuffer = FileUtil.loadMappedFile(context, "mydeeplabv3.tflite") Log.i(TAG, "Loaded LiteRT model") @@ -49,8 +42,7 @@ class ImageSegmentationService(private val context: Context) { } Interpreter(litertBuffer, options) } catch (e: Exception) { - Log.i(TAG, "Failed to load LiteRT model: ${e.message}") - _error.emit(e) + Log.e(TAG, "Failed to load LiteRT model: ${e.message}") null } } @@ -91,8 +83,7 @@ class ImageSegmentationService(private val context: Context) { } } } catch (e: Exception) { - Log.i(TAG, "Image segmentation error occurred: ${e.message}") - _error.emit(e) + Log.e(TAG, "Error occurred in image segmentation: ${e.message}") } } @@ -139,7 +130,7 @@ class ImageSegmentationService(private val context: Context) { outputBuffer.rewind() val inferenceData = InferenceData(width = w, height = h, channels = c, buffer = outputBuffer) - val mask = processImage(inferenceData) + val mask = generateMaskFromOutputBuffer(inferenceData) val imageProperties = ImageProperties @@ -153,26 +144,26 @@ class ImageSegmentationService(private val context: Context) { return Segmentation(maskImage) } - private fun processImage(inferenceData: InferenceData): ByteBuffer { - val mask = ByteBuffer.allocateDirect(inferenceData.width * inferenceData.height) - for (i in 0 until inferenceData.height) { - for (j in 0 until inferenceData.width) { + private fun generateMaskFromOutputBuffer(inferenceData: InferenceData): ByteBuffer { + val width = inferenceData.width + val height = inferenceData.height + val mask = ByteBuffer.allocateDirect(width * height) + for (i in 0 until height) { + for (j in 0 until width) { var maxIndex = 0 - var maxValue = inferenceData.buffer.get(i * inferenceData.width + j) + var maxValue = inferenceData.buffer[i * width + j] for (c in 1 until inferenceData.channels) { - val value = inferenceData.buffer.get( - c * inferenceData.height * inferenceData.width + i * inferenceData.width + j) + val value = inferenceData.buffer[c * height * width + i * width + j] if (value > maxValue) { maxValue = value maxIndex = c } } - mask.put(i * inferenceData.width + j, maxIndex.toByte()) + mask.put(i * width + j, maxIndex.toByte()) } } - return mask }