ImageSegmentation: small refactorings
This commit is contained in:
@@ -9,9 +9,7 @@ import android.os.SystemClock
|
||||
import android.util.Log
|
||||
import androidx.core.graphics.scale
|
||||
import kotlinx.coroutines.Dispatchers
|
||||
import kotlinx.coroutines.flow.MutableSharedFlow
|
||||
import kotlinx.coroutines.flow.MutableStateFlow
|
||||
import kotlinx.coroutines.flow.SharedFlow
|
||||
import kotlinx.coroutines.flow.StateFlow
|
||||
import kotlinx.coroutines.flow.asStateFlow
|
||||
import kotlinx.coroutines.isActive
|
||||
@@ -24,7 +22,6 @@ import org.tensorflow.lite.support.image.TensorImage
|
||||
import java.nio.ByteBuffer
|
||||
import java.nio.FloatBuffer
|
||||
|
||||
// TODO Review and remove unneeded code
|
||||
class ImageSegmentationService(private val context: Context) {
|
||||
|
||||
companion object {
|
||||
@@ -34,13 +31,9 @@ class ImageSegmentationService(private val context: Context) {
|
||||
private val _segmentation = MutableStateFlow<SegmentationResult?>(null)
|
||||
val segmentation: StateFlow<SegmentationResult?> = _segmentation.asStateFlow()
|
||||
|
||||
val error: SharedFlow<Throwable?>
|
||||
get() = _error
|
||||
private val _error = MutableSharedFlow<Throwable?>()
|
||||
|
||||
private var interpreter: Interpreter? = null
|
||||
|
||||
suspend fun initialize() {
|
||||
fun initialize() {
|
||||
interpreter = try {
|
||||
val litertBuffer = FileUtil.loadMappedFile(context, "mydeeplabv3.tflite")
|
||||
Log.i(TAG, "Loaded LiteRT model")
|
||||
@@ -49,8 +42,7 @@ class ImageSegmentationService(private val context: Context) {
|
||||
}
|
||||
Interpreter(litertBuffer, options)
|
||||
} catch (e: Exception) {
|
||||
Log.i(TAG, "Failed to load LiteRT model: ${e.message}")
|
||||
_error.emit(e)
|
||||
Log.e(TAG, "Failed to load LiteRT model: ${e.message}")
|
||||
null
|
||||
}
|
||||
}
|
||||
@@ -91,8 +83,7 @@ class ImageSegmentationService(private val context: Context) {
|
||||
}
|
||||
}
|
||||
} catch (e: Exception) {
|
||||
Log.i(TAG, "Image segmentation error occurred: ${e.message}")
|
||||
_error.emit(e)
|
||||
Log.e(TAG, "Error occurred in image segmentation: ${e.message}")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -139,7 +130,7 @@ class ImageSegmentationService(private val context: Context) {
|
||||
outputBuffer.rewind()
|
||||
val inferenceData =
|
||||
InferenceData(width = w, height = h, channels = c, buffer = outputBuffer)
|
||||
val mask = processImage(inferenceData)
|
||||
val mask = generateMaskFromOutputBuffer(inferenceData)
|
||||
|
||||
val imageProperties =
|
||||
ImageProperties
|
||||
@@ -153,26 +144,26 @@ class ImageSegmentationService(private val context: Context) {
|
||||
return Segmentation(maskImage)
|
||||
}
|
||||
|
||||
private fun processImage(inferenceData: InferenceData): ByteBuffer {
|
||||
val mask = ByteBuffer.allocateDirect(inferenceData.width * inferenceData.height)
|
||||
for (i in 0 until inferenceData.height) {
|
||||
for (j in 0 until inferenceData.width) {
|
||||
private fun generateMaskFromOutputBuffer(inferenceData: InferenceData): ByteBuffer {
|
||||
val width = inferenceData.width
|
||||
val height = inferenceData.height
|
||||
val mask = ByteBuffer.allocateDirect(width * height)
|
||||
for (i in 0 until height) {
|
||||
for (j in 0 until width) {
|
||||
var maxIndex = 0
|
||||
var maxValue = inferenceData.buffer.get(i * inferenceData.width + j)
|
||||
var maxValue = inferenceData.buffer[i * width + j]
|
||||
|
||||
for (c in 1 until inferenceData.channels) {
|
||||
val value = inferenceData.buffer.get(
|
||||
c * inferenceData.height * inferenceData.width + i * inferenceData.width + j)
|
||||
val value = inferenceData.buffer[c * height * width + i * width + j]
|
||||
if (value > maxValue) {
|
||||
maxValue = value
|
||||
maxIndex = c
|
||||
}
|
||||
}
|
||||
|
||||
mask.put(i * inferenceData.width + j, maxIndex.toByte())
|
||||
mask.put(i * width + j, maxIndex.toByte())
|
||||
}
|
||||
}
|
||||
|
||||
return mask
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user