ImageSegmentation: small refactorings

This commit is contained in:
Pierre-Yves Nicolas
2025-06-01 06:58:51 +02:00
parent f8b9f47782
commit 73b0d47796

View File

@@ -9,9 +9,7 @@ import android.os.SystemClock
import android.util.Log
import androidx.core.graphics.scale
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.flow.MutableSharedFlow
import kotlinx.coroutines.flow.MutableStateFlow
import kotlinx.coroutines.flow.SharedFlow
import kotlinx.coroutines.flow.StateFlow
import kotlinx.coroutines.flow.asStateFlow
import kotlinx.coroutines.isActive
@@ -24,7 +22,6 @@ import org.tensorflow.lite.support.image.TensorImage
import java.nio.ByteBuffer
import java.nio.FloatBuffer
// TODO Review and remove unneeded code
class ImageSegmentationService(private val context: Context) {
companion object {
@@ -34,13 +31,9 @@ class ImageSegmentationService(private val context: Context) {
private val _segmentation = MutableStateFlow<SegmentationResult?>(null)
val segmentation: StateFlow<SegmentationResult?> = _segmentation.asStateFlow()
val error: SharedFlow<Throwable?>
get() = _error
private val _error = MutableSharedFlow<Throwable?>()
private var interpreter: Interpreter? = null
suspend fun initialize() {
fun initialize() {
interpreter = try {
val litertBuffer = FileUtil.loadMappedFile(context, "mydeeplabv3.tflite")
Log.i(TAG, "Loaded LiteRT model")
@@ -49,8 +42,7 @@ class ImageSegmentationService(private val context: Context) {
}
Interpreter(litertBuffer, options)
} catch (e: Exception) {
Log.i(TAG, "Failed to load LiteRT model: ${e.message}")
_error.emit(e)
Log.e(TAG, "Failed to load LiteRT model: ${e.message}")
null
}
}
@@ -91,8 +83,7 @@ class ImageSegmentationService(private val context: Context) {
}
}
} catch (e: Exception) {
Log.i(TAG, "Image segmentation error occurred: ${e.message}")
_error.emit(e)
Log.e(TAG, "Error occurred in image segmentation: ${e.message}")
}
}
@@ -139,7 +130,7 @@ class ImageSegmentationService(private val context: Context) {
outputBuffer.rewind()
val inferenceData =
InferenceData(width = w, height = h, channels = c, buffer = outputBuffer)
val mask = processImage(inferenceData)
val mask = generateMaskFromOutputBuffer(inferenceData)
val imageProperties =
ImageProperties
@@ -153,26 +144,26 @@ class ImageSegmentationService(private val context: Context) {
return Segmentation(maskImage)
}
private fun processImage(inferenceData: InferenceData): ByteBuffer {
val mask = ByteBuffer.allocateDirect(inferenceData.width * inferenceData.height)
for (i in 0 until inferenceData.height) {
for (j in 0 until inferenceData.width) {
private fun generateMaskFromOutputBuffer(inferenceData: InferenceData): ByteBuffer {
val width = inferenceData.width
val height = inferenceData.height
val mask = ByteBuffer.allocateDirect(width * height)
for (i in 0 until height) {
for (j in 0 until width) {
var maxIndex = 0
var maxValue = inferenceData.buffer.get(i * inferenceData.width + j)
var maxValue = inferenceData.buffer[i * width + j]
for (c in 1 until inferenceData.channels) {
val value = inferenceData.buffer.get(
c * inferenceData.height * inferenceData.width + i * inferenceData.width + j)
val value = inferenceData.buffer[c * height * width + i * width + j]
if (value > maxValue) {
maxValue = value
maxIndex = c
}
}
mask.put(i * inferenceData.width + j, maxIndex.toByte())
mask.put(i * width + j, maxIndex.toByte())
}
}
return mask
}