New segmentation model
This commit is contained in:
committed by
pynicolas
parent
e3ada11c8c
commit
0c5a219783
@@ -18,23 +18,24 @@ import android.content.Context
|
||||
import android.graphics.Bitmap
|
||||
import android.graphics.Bitmap.createBitmap
|
||||
import android.graphics.Color
|
||||
import android.graphics.Matrix
|
||||
import android.os.SystemClock
|
||||
import android.util.Log
|
||||
import androidx.core.graphics.scale
|
||||
import kotlinx.coroutines.Dispatchers
|
||||
import kotlinx.coroutines.flow.MutableStateFlow
|
||||
import kotlinx.coroutines.flow.StateFlow
|
||||
import kotlinx.coroutines.flow.asStateFlow
|
||||
import kotlinx.coroutines.isActive
|
||||
import kotlinx.coroutines.withContext
|
||||
import org.tensorflow.lite.DataType
|
||||
import org.tensorflow.lite.Interpreter
|
||||
import org.tensorflow.lite.support.common.FileUtil
|
||||
import org.tensorflow.lite.support.image.ColorSpaceType
|
||||
import org.tensorflow.lite.support.image.ImageProperties
|
||||
import org.tensorflow.lite.support.common.ops.NormalizeOp
|
||||
import org.tensorflow.lite.support.image.ImageProcessor
|
||||
import org.tensorflow.lite.support.image.TensorImage
|
||||
import org.tensorflow.lite.support.image.ops.ResizeOp
|
||||
import org.tensorflow.lite.support.image.ops.Rot90Op
|
||||
import java.nio.ByteBuffer
|
||||
import java.nio.FloatBuffer
|
||||
import java.nio.ByteOrder
|
||||
|
||||
class ImageSegmentationService(private val context: Context) {
|
||||
|
||||
@@ -49,7 +50,7 @@ class ImageSegmentationService(private val context: Context) {
|
||||
|
||||
fun initialize() {
|
||||
interpreter = try {
|
||||
val litertBuffer = FileUtil.loadMappedFile(context, "mydeeplabv3.tflite")
|
||||
val litertBuffer = FileUtil.loadMappedFile(context, "timm_efficientnet_lite0_quantized.tflite")
|
||||
Log.i(TAG, "Loaded LiteRT model")
|
||||
val options = Interpreter.Options().apply {
|
||||
numThreads = 2
|
||||
@@ -64,19 +65,21 @@ class ImageSegmentationService(private val context: Context) {
|
||||
private fun runSegmentation(interpreter: Interpreter, bitmap: Bitmap, rotationDegrees: Int): SegmentationResult {
|
||||
val startTime = SystemClock.uptimeMillis()
|
||||
|
||||
val (_, _, h, w) = interpreter.getInputTensor(0).shape()
|
||||
// Preprocess manually into CHW float buffer
|
||||
val inputBuffer = bitmapToCHWFloatBuffer(bitmap, width = w, height = h, rotationDegrees)
|
||||
|
||||
val (_, cOut, hOut, wOut) = interpreter.getOutputTensor(0).shape()
|
||||
val outputBuffer = FloatBuffer.allocate(cOut * hOut * wOut)
|
||||
|
||||
// Run inference
|
||||
outputBuffer.rewind()
|
||||
interpreter.run(inputBuffer, outputBuffer)
|
||||
val rotation = -rotationDegrees / 90
|
||||
val (_, h, w, _) = interpreter.getOutputTensor(0).shape()
|
||||
val imageProcessor =
|
||||
ImageProcessor
|
||||
.Builder()
|
||||
.add(ResizeOp(h, w, ResizeOp.ResizeMethod.BILINEAR))
|
||||
.add(Rot90Op(rotation))
|
||||
.add(NormalizeOp(127.5f, 127.5f)) // TODO check if it's correct
|
||||
.build()
|
||||
val tensorImage = TensorImage(DataType.FLOAT32)
|
||||
tensorImage.load(bitmap)
|
||||
val processedImage = imageProcessor.process(tensorImage)
|
||||
val segmentResult = segment(interpreter, processedImage)
|
||||
|
||||
val inferenceTime = SystemClock.uptimeMillis() - startTime
|
||||
val segmentResult = processOutputBuffer(outputBuffer, wOut, hOut, cOut)
|
||||
return SegmentationResult(segmentResult, inferenceTime)
|
||||
}
|
||||
|
||||
@@ -101,111 +104,40 @@ class ImageSegmentationService(private val context: Context) {
|
||||
}
|
||||
}
|
||||
|
||||
fun bitmapToCHWFloatBuffer(bitmap: Bitmap, width: Int, height: Int, rotationDegrees: Int): FloatBuffer {
|
||||
val rotatedBitmap = if (rotationDegrees != 0) {
|
||||
val matrix = Matrix().apply { postRotate(rotationDegrees.toFloat()) }
|
||||
createBitmap(bitmap, 0, 0, bitmap.width, bitmap.height, matrix, true)
|
||||
} else {
|
||||
bitmap
|
||||
}
|
||||
private fun segment(interpreter: Interpreter, tensorImage: TensorImage): Segmentation {
|
||||
val (_, h, w, _) = interpreter.getOutputTensor(0).shape()
|
||||
val outputBuffer = ByteBuffer.allocateDirect(4 * h * w)
|
||||
outputBuffer.order(ByteOrder.nativeOrder())
|
||||
outputBuffer.rewind()
|
||||
interpreter.run(tensorImage.tensorBuffer.buffer, outputBuffer)
|
||||
outputBuffer.rewind()
|
||||
val mask = generateMaskFromOutputBuffer(outputBuffer, w, h)
|
||||
return Segmentation(mask)
|
||||
}
|
||||
|
||||
val resized = rotatedBitmap.scale(width, height)
|
||||
val buffer = FloatBuffer.allocate(1 * 3 * height * width)
|
||||
buffer.rewind()
|
||||
|
||||
val mean = floatArrayOf(0.4611f, 0.4359f, 0.3905f)
|
||||
val std = floatArrayOf(0.2193f, 0.2150f, 0.2109f)
|
||||
private fun generateMaskFromOutputBuffer(outputBuffer: ByteBuffer, width: Int, height: Int): Bitmap {
|
||||
outputBuffer.rewind()
|
||||
val floatArray = FloatArray(width * height)
|
||||
outputBuffer.asFloatBuffer()[floatArray]
|
||||
|
||||
val pixels = IntArray(width * height)
|
||||
resized.getPixels(pixels, 0, width, 0, 0, width, height)
|
||||
|
||||
// Fill buffer in CHW order
|
||||
for (c in 0..2) {
|
||||
for (i in 0 until height) {
|
||||
for (j in 0 until width) {
|
||||
val pixel = pixels[i * width + j]
|
||||
val value = when (c) {
|
||||
0 -> (pixel shr 16 and 0xFF) // R
|
||||
1 -> (pixel shr 8 and 0xFF) // G
|
||||
2 -> (pixel and 0xFF) // B
|
||||
else -> 0
|
||||
}
|
||||
val normalized = (value / 255f - mean[c]) / std[c]
|
||||
buffer.put(normalized)
|
||||
}
|
||||
}
|
||||
for (i in floatArray.indices) {
|
||||
val value = floatArray[i].coerceIn(0f, 1f)
|
||||
val gray = (value * 255).toInt()
|
||||
pixels[i] = Color.rgb(gray, gray, gray)
|
||||
}
|
||||
|
||||
buffer.rewind()
|
||||
return buffer
|
||||
val bitmap = createBitmap(width, height, Bitmap.Config.ARGB_8888)
|
||||
bitmap.setPixels(pixels, 0, width, 0, 0, width, height)
|
||||
return bitmap
|
||||
}
|
||||
|
||||
private fun processOutputBuffer(outputBuffer: FloatBuffer, w: Int, h: Int, c: Int): Segmentation {
|
||||
outputBuffer.rewind()
|
||||
val inferenceData =
|
||||
InferenceData(width = w, height = h, channels = c, buffer = outputBuffer)
|
||||
val mask = generateMaskFromOutputBuffer(inferenceData)
|
||||
|
||||
val imageProperties =
|
||||
ImageProperties
|
||||
.builder()
|
||||
.setWidth(inferenceData.width)
|
||||
.setHeight(inferenceData.height)
|
||||
.setColorSpaceType(ColorSpaceType.GRAYSCALE)
|
||||
.build()
|
||||
val maskImage = TensorImage()
|
||||
maskImage.load(mask, imageProperties)
|
||||
return Segmentation(maskImage)
|
||||
}
|
||||
|
||||
private fun generateMaskFromOutputBuffer(inferenceData: InferenceData): ByteBuffer {
|
||||
val width = inferenceData.width
|
||||
val height = inferenceData.height
|
||||
val mask = ByteBuffer.allocateDirect(width * height)
|
||||
for (i in 0 until height) {
|
||||
for (j in 0 until width) {
|
||||
var maxIndex = 0
|
||||
var maxValue = inferenceData.buffer[i * width + j]
|
||||
|
||||
for (c in 1 until inferenceData.channels) {
|
||||
val value = inferenceData.buffer[c * height * width + i * width + j]
|
||||
if (value > maxValue) {
|
||||
maxValue = value
|
||||
maxIndex = c
|
||||
}
|
||||
}
|
||||
|
||||
mask.put(i * width + j, maxIndex.toByte())
|
||||
}
|
||||
}
|
||||
return mask
|
||||
}
|
||||
|
||||
data class Segmentation(val mask: TensorImage) {
|
||||
fun toBinaryMask(): Bitmap {
|
||||
val width = mask.width
|
||||
val height = mask.height
|
||||
val pixels = IntArray(width * height)
|
||||
for (i in 0 until height) {
|
||||
for (j in 0 until width) {
|
||||
val index = i * width + j
|
||||
val classId = mask.buffer[index].toInt() and 0xFF // Unsigned byte
|
||||
pixels[index] = if (classId == 0) Color.BLACK else Color.WHITE
|
||||
}
|
||||
}
|
||||
return createBitmap(pixels, width, height, Bitmap.Config.ARGB_8888)
|
||||
}
|
||||
data class Segmentation(val mask: Bitmap) {
|
||||
fun toBinaryMask(): Bitmap = mask
|
||||
}
|
||||
|
||||
data class SegmentationResult(
|
||||
val segmentation: Segmentation,
|
||||
val inferenceTime: Long
|
||||
)
|
||||
|
||||
data class InferenceData(
|
||||
val width: Int,
|
||||
val height: Int,
|
||||
val channels: Int,
|
||||
val buffer: FloatBuffer,
|
||||
)
|
||||
}
|
||||
|
||||
@@ -60,6 +60,8 @@ class MainViewModel(
|
||||
private var _pageToValidate = MutableStateFlow<Bitmap?>(null)
|
||||
val pageToValidate: StateFlow<Bitmap?> = _pageToValidate.asStateFlow()
|
||||
|
||||
var liveAnalysisEnabled = true
|
||||
|
||||
init {
|
||||
viewModelScope.launch {
|
||||
imageSegmentationService.initialize()
|
||||
@@ -80,6 +82,11 @@ class MainViewModel(
|
||||
}
|
||||
|
||||
fun segment(imageProxy: ImageProxy) {
|
||||
if (!liveAnalysisEnabled) {
|
||||
imageProxy.close()
|
||||
return
|
||||
}
|
||||
|
||||
viewModelScope.launch {
|
||||
imageSegmentationService.runSegmentationAndEmit(
|
||||
imageProxy.toBitmap(),
|
||||
|
||||
@@ -131,6 +131,7 @@ fun CameraScreen(
|
||||
pageCount = viewModel.pageCount(),
|
||||
liveAnalysisState,
|
||||
onCapture = {
|
||||
viewModel.liveAnalysisEnabled = false
|
||||
showPageDialog.value = true
|
||||
isProcessing.value = true
|
||||
captureController.takePicture(
|
||||
@@ -138,6 +139,7 @@ fun CameraScreen(
|
||||
if (imageProxy != null) {
|
||||
viewModel.processCapturedImageThen(imageProxy) {
|
||||
isProcessing.value = false
|
||||
viewModel.liveAnalysisEnabled = true
|
||||
}
|
||||
} else {
|
||||
Log.e("MyScan", "Error during image capture")
|
||||
|
||||
Reference in New Issue
Block a user