ImageSegmentation: small refactorings
This commit is contained in:
@@ -9,9 +9,7 @@ import android.os.SystemClock
|
|||||||
import android.util.Log
|
import android.util.Log
|
||||||
import androidx.core.graphics.scale
|
import androidx.core.graphics.scale
|
||||||
import kotlinx.coroutines.Dispatchers
|
import kotlinx.coroutines.Dispatchers
|
||||||
import kotlinx.coroutines.flow.MutableSharedFlow
|
|
||||||
import kotlinx.coroutines.flow.MutableStateFlow
|
import kotlinx.coroutines.flow.MutableStateFlow
|
||||||
import kotlinx.coroutines.flow.SharedFlow
|
|
||||||
import kotlinx.coroutines.flow.StateFlow
|
import kotlinx.coroutines.flow.StateFlow
|
||||||
import kotlinx.coroutines.flow.asStateFlow
|
import kotlinx.coroutines.flow.asStateFlow
|
||||||
import kotlinx.coroutines.isActive
|
import kotlinx.coroutines.isActive
|
||||||
@@ -24,7 +22,6 @@ import org.tensorflow.lite.support.image.TensorImage
|
|||||||
import java.nio.ByteBuffer
|
import java.nio.ByteBuffer
|
||||||
import java.nio.FloatBuffer
|
import java.nio.FloatBuffer
|
||||||
|
|
||||||
// TODO Review and remove unneeded code
|
|
||||||
class ImageSegmentationService(private val context: Context) {
|
class ImageSegmentationService(private val context: Context) {
|
||||||
|
|
||||||
companion object {
|
companion object {
|
||||||
@@ -34,13 +31,9 @@ class ImageSegmentationService(private val context: Context) {
|
|||||||
private val _segmentation = MutableStateFlow<SegmentationResult?>(null)
|
private val _segmentation = MutableStateFlow<SegmentationResult?>(null)
|
||||||
val segmentation: StateFlow<SegmentationResult?> = _segmentation.asStateFlow()
|
val segmentation: StateFlow<SegmentationResult?> = _segmentation.asStateFlow()
|
||||||
|
|
||||||
val error: SharedFlow<Throwable?>
|
|
||||||
get() = _error
|
|
||||||
private val _error = MutableSharedFlow<Throwable?>()
|
|
||||||
|
|
||||||
private var interpreter: Interpreter? = null
|
private var interpreter: Interpreter? = null
|
||||||
|
|
||||||
suspend fun initialize() {
|
fun initialize() {
|
||||||
interpreter = try {
|
interpreter = try {
|
||||||
val litertBuffer = FileUtil.loadMappedFile(context, "mydeeplabv3.tflite")
|
val litertBuffer = FileUtil.loadMappedFile(context, "mydeeplabv3.tflite")
|
||||||
Log.i(TAG, "Loaded LiteRT model")
|
Log.i(TAG, "Loaded LiteRT model")
|
||||||
@@ -49,8 +42,7 @@ class ImageSegmentationService(private val context: Context) {
|
|||||||
}
|
}
|
||||||
Interpreter(litertBuffer, options)
|
Interpreter(litertBuffer, options)
|
||||||
} catch (e: Exception) {
|
} catch (e: Exception) {
|
||||||
Log.i(TAG, "Failed to load LiteRT model: ${e.message}")
|
Log.e(TAG, "Failed to load LiteRT model: ${e.message}")
|
||||||
_error.emit(e)
|
|
||||||
null
|
null
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -91,8 +83,7 @@ class ImageSegmentationService(private val context: Context) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
} catch (e: Exception) {
|
} catch (e: Exception) {
|
||||||
Log.i(TAG, "Image segmentation error occurred: ${e.message}")
|
Log.e(TAG, "Error occurred in image segmentation: ${e.message}")
|
||||||
_error.emit(e)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -139,7 +130,7 @@ class ImageSegmentationService(private val context: Context) {
|
|||||||
outputBuffer.rewind()
|
outputBuffer.rewind()
|
||||||
val inferenceData =
|
val inferenceData =
|
||||||
InferenceData(width = w, height = h, channels = c, buffer = outputBuffer)
|
InferenceData(width = w, height = h, channels = c, buffer = outputBuffer)
|
||||||
val mask = processImage(inferenceData)
|
val mask = generateMaskFromOutputBuffer(inferenceData)
|
||||||
|
|
||||||
val imageProperties =
|
val imageProperties =
|
||||||
ImageProperties
|
ImageProperties
|
||||||
@@ -153,26 +144,26 @@ class ImageSegmentationService(private val context: Context) {
|
|||||||
return Segmentation(maskImage)
|
return Segmentation(maskImage)
|
||||||
}
|
}
|
||||||
|
|
||||||
private fun processImage(inferenceData: InferenceData): ByteBuffer {
|
private fun generateMaskFromOutputBuffer(inferenceData: InferenceData): ByteBuffer {
|
||||||
val mask = ByteBuffer.allocateDirect(inferenceData.width * inferenceData.height)
|
val width = inferenceData.width
|
||||||
for (i in 0 until inferenceData.height) {
|
val height = inferenceData.height
|
||||||
for (j in 0 until inferenceData.width) {
|
val mask = ByteBuffer.allocateDirect(width * height)
|
||||||
|
for (i in 0 until height) {
|
||||||
|
for (j in 0 until width) {
|
||||||
var maxIndex = 0
|
var maxIndex = 0
|
||||||
var maxValue = inferenceData.buffer.get(i * inferenceData.width + j)
|
var maxValue = inferenceData.buffer[i * width + j]
|
||||||
|
|
||||||
for (c in 1 until inferenceData.channels) {
|
for (c in 1 until inferenceData.channels) {
|
||||||
val value = inferenceData.buffer.get(
|
val value = inferenceData.buffer[c * height * width + i * width + j]
|
||||||
c * inferenceData.height * inferenceData.width + i * inferenceData.width + j)
|
|
||||||
if (value > maxValue) {
|
if (value > maxValue) {
|
||||||
maxValue = value
|
maxValue = value
|
||||||
maxIndex = c
|
maxIndex = c
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
mask.put(i * inferenceData.width + j, maxIndex.toByte())
|
mask.put(i * width + j, maxIndex.toByte())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return mask
|
return mask
|
||||||
}
|
}
|
||||||
|
|
||||||
Reference in New Issue
Block a user