ImageSegmentation: small refactorings

This commit is contained in:
Pierre-Yves Nicolas
2025-06-01 06:58:51 +02:00
parent f8b9f47782
commit 73b0d47796

View File

@@ -9,9 +9,7 @@ import android.os.SystemClock
import android.util.Log import android.util.Log
import androidx.core.graphics.scale import androidx.core.graphics.scale
import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.flow.MutableSharedFlow
import kotlinx.coroutines.flow.MutableStateFlow import kotlinx.coroutines.flow.MutableStateFlow
import kotlinx.coroutines.flow.SharedFlow
import kotlinx.coroutines.flow.StateFlow import kotlinx.coroutines.flow.StateFlow
import kotlinx.coroutines.flow.asStateFlow import kotlinx.coroutines.flow.asStateFlow
import kotlinx.coroutines.isActive import kotlinx.coroutines.isActive
@@ -24,7 +22,6 @@ import org.tensorflow.lite.support.image.TensorImage
import java.nio.ByteBuffer import java.nio.ByteBuffer
import java.nio.FloatBuffer import java.nio.FloatBuffer
// TODO Review and remove unneeded code
class ImageSegmentationService(private val context: Context) { class ImageSegmentationService(private val context: Context) {
companion object { companion object {
@@ -34,13 +31,9 @@ class ImageSegmentationService(private val context: Context) {
private val _segmentation = MutableStateFlow<SegmentationResult?>(null) private val _segmentation = MutableStateFlow<SegmentationResult?>(null)
val segmentation: StateFlow<SegmentationResult?> = _segmentation.asStateFlow() val segmentation: StateFlow<SegmentationResult?> = _segmentation.asStateFlow()
val error: SharedFlow<Throwable?>
get() = _error
private val _error = MutableSharedFlow<Throwable?>()
private var interpreter: Interpreter? = null private var interpreter: Interpreter? = null
suspend fun initialize() { fun initialize() {
interpreter = try { interpreter = try {
val litertBuffer = FileUtil.loadMappedFile(context, "mydeeplabv3.tflite") val litertBuffer = FileUtil.loadMappedFile(context, "mydeeplabv3.tflite")
Log.i(TAG, "Loaded LiteRT model") Log.i(TAG, "Loaded LiteRT model")
@@ -49,8 +42,7 @@ class ImageSegmentationService(private val context: Context) {
} }
Interpreter(litertBuffer, options) Interpreter(litertBuffer, options)
} catch (e: Exception) { } catch (e: Exception) {
Log.i(TAG, "Failed to load LiteRT model: ${e.message}") Log.e(TAG, "Failed to load LiteRT model: ${e.message}")
_error.emit(e)
null null
} }
} }
@@ -91,8 +83,7 @@ class ImageSegmentationService(private val context: Context) {
} }
} }
} catch (e: Exception) { } catch (e: Exception) {
Log.i(TAG, "Image segmentation error occurred: ${e.message}") Log.e(TAG, "Error occurred in image segmentation: ${e.message}")
_error.emit(e)
} }
} }
@@ -139,7 +130,7 @@ class ImageSegmentationService(private val context: Context) {
outputBuffer.rewind() outputBuffer.rewind()
val inferenceData = val inferenceData =
InferenceData(width = w, height = h, channels = c, buffer = outputBuffer) InferenceData(width = w, height = h, channels = c, buffer = outputBuffer)
val mask = processImage(inferenceData) val mask = generateMaskFromOutputBuffer(inferenceData)
val imageProperties = val imageProperties =
ImageProperties ImageProperties
@@ -153,26 +144,26 @@ class ImageSegmentationService(private val context: Context) {
return Segmentation(maskImage) return Segmentation(maskImage)
} }
private fun processImage(inferenceData: InferenceData): ByteBuffer { private fun generateMaskFromOutputBuffer(inferenceData: InferenceData): ByteBuffer {
val mask = ByteBuffer.allocateDirect(inferenceData.width * inferenceData.height) val width = inferenceData.width
for (i in 0 until inferenceData.height) { val height = inferenceData.height
for (j in 0 until inferenceData.width) { val mask = ByteBuffer.allocateDirect(width * height)
for (i in 0 until height) {
for (j in 0 until width) {
var maxIndex = 0 var maxIndex = 0
var maxValue = inferenceData.buffer.get(i * inferenceData.width + j) var maxValue = inferenceData.buffer[i * width + j]
for (c in 1 until inferenceData.channels) { for (c in 1 until inferenceData.channels) {
val value = inferenceData.buffer.get( val value = inferenceData.buffer[c * height * width + i * width + j]
c * inferenceData.height * inferenceData.width + i * inferenceData.width + j)
if (value > maxValue) { if (value > maxValue) {
maxValue = value maxValue = value
maxIndex = c maxIndex = c
} }
} }
mask.put(i * inferenceData.width + j, maxIndex.toByte()) mask.put(i * width + j, maxIndex.toByte())
} }
} }
return mask return mask
} }