New segmentation model

This commit is contained in:
Pierre-Yves Nicolas
2025-06-19 15:00:20 +02:00
committed by pynicolas
parent e3ada11c8c
commit 0c5a219783
3 changed files with 52 additions and 111 deletions

View File

@@ -18,23 +18,24 @@ import android.content.Context
import android.graphics.Bitmap
import android.graphics.Bitmap.createBitmap
import android.graphics.Color
import android.graphics.Matrix
import android.os.SystemClock
import android.util.Log
import androidx.core.graphics.scale
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.flow.MutableStateFlow
import kotlinx.coroutines.flow.StateFlow
import kotlinx.coroutines.flow.asStateFlow
import kotlinx.coroutines.isActive
import kotlinx.coroutines.withContext
import org.tensorflow.lite.DataType
import org.tensorflow.lite.Interpreter
import org.tensorflow.lite.support.common.FileUtil
import org.tensorflow.lite.support.image.ColorSpaceType
import org.tensorflow.lite.support.image.ImageProperties
import org.tensorflow.lite.support.common.ops.NormalizeOp
import org.tensorflow.lite.support.image.ImageProcessor
import org.tensorflow.lite.support.image.TensorImage
import org.tensorflow.lite.support.image.ops.ResizeOp
import org.tensorflow.lite.support.image.ops.Rot90Op
import java.nio.ByteBuffer
import java.nio.FloatBuffer
import java.nio.ByteOrder
class ImageSegmentationService(private val context: Context) {
@@ -49,7 +50,7 @@ class ImageSegmentationService(private val context: Context) {
fun initialize() {
interpreter = try {
val litertBuffer = FileUtil.loadMappedFile(context, "mydeeplabv3.tflite")
val litertBuffer = FileUtil.loadMappedFile(context, "timm_efficientnet_lite0_quantized.tflite")
Log.i(TAG, "Loaded LiteRT model")
val options = Interpreter.Options().apply {
numThreads = 2
@@ -64,19 +65,21 @@ class ImageSegmentationService(private val context: Context) {
private fun runSegmentation(interpreter: Interpreter, bitmap: Bitmap, rotationDegrees: Int): SegmentationResult {
val startTime = SystemClock.uptimeMillis()
val (_, _, h, w) = interpreter.getInputTensor(0).shape()
// Preprocess manually into CHW float buffer
val inputBuffer = bitmapToCHWFloatBuffer(bitmap, width = w, height = h, rotationDegrees)
val (_, cOut, hOut, wOut) = interpreter.getOutputTensor(0).shape()
val outputBuffer = FloatBuffer.allocate(cOut * hOut * wOut)
// Run inference
outputBuffer.rewind()
interpreter.run(inputBuffer, outputBuffer)
val rotation = -rotationDegrees / 90
val (_, h, w, _) = interpreter.getOutputTensor(0).shape()
val imageProcessor =
ImageProcessor
.Builder()
.add(ResizeOp(h, w, ResizeOp.ResizeMethod.BILINEAR))
.add(Rot90Op(rotation))
.add(NormalizeOp(127.5f, 127.5f)) // TODO check if it's correct
.build()
val tensorImage = TensorImage(DataType.FLOAT32)
tensorImage.load(bitmap)
val processedImage = imageProcessor.process(tensorImage)
val segmentResult = segment(interpreter, processedImage)
val inferenceTime = SystemClock.uptimeMillis() - startTime
val segmentResult = processOutputBuffer(outputBuffer, wOut, hOut, cOut)
return SegmentationResult(segmentResult, inferenceTime)
}
@@ -101,111 +104,40 @@ class ImageSegmentationService(private val context: Context) {
}
}
fun bitmapToCHWFloatBuffer(bitmap: Bitmap, width: Int, height: Int, rotationDegrees: Int): FloatBuffer {
val rotatedBitmap = if (rotationDegrees != 0) {
val matrix = Matrix().apply { postRotate(rotationDegrees.toFloat()) }
createBitmap(bitmap, 0, 0, bitmap.width, bitmap.height, matrix, true)
} else {
bitmap
}
private fun segment(interpreter: Interpreter, tensorImage: TensorImage): Segmentation {
val (_, h, w, _) = interpreter.getOutputTensor(0).shape()
val outputBuffer = ByteBuffer.allocateDirect(4 * h * w)
outputBuffer.order(ByteOrder.nativeOrder())
outputBuffer.rewind()
interpreter.run(tensorImage.tensorBuffer.buffer, outputBuffer)
outputBuffer.rewind()
val mask = generateMaskFromOutputBuffer(outputBuffer, w, h)
return Segmentation(mask)
}
val resized = rotatedBitmap.scale(width, height)
val buffer = FloatBuffer.allocate(1 * 3 * height * width)
buffer.rewind()
val mean = floatArrayOf(0.4611f, 0.4359f, 0.3905f)
val std = floatArrayOf(0.2193f, 0.2150f, 0.2109f)
private fun generateMaskFromOutputBuffer(outputBuffer: ByteBuffer, width: Int, height: Int): Bitmap {
outputBuffer.rewind()
val floatArray = FloatArray(width * height)
outputBuffer.asFloatBuffer()[floatArray]
val pixels = IntArray(width * height)
resized.getPixels(pixels, 0, width, 0, 0, width, height)
// Fill buffer in CHW order
for (c in 0..2) {
for (i in 0 until height) {
for (j in 0 until width) {
val pixel = pixels[i * width + j]
val value = when (c) {
0 -> (pixel shr 16 and 0xFF) // R
1 -> (pixel shr 8 and 0xFF) // G
2 -> (pixel and 0xFF) // B
else -> 0
}
val normalized = (value / 255f - mean[c]) / std[c]
buffer.put(normalized)
}
}
for (i in floatArray.indices) {
val value = floatArray[i].coerceIn(0f, 1f)
val gray = (value * 255).toInt()
pixels[i] = Color.rgb(gray, gray, gray)
}
buffer.rewind()
return buffer
val bitmap = createBitmap(width, height, Bitmap.Config.ARGB_8888)
bitmap.setPixels(pixels, 0, width, 0, 0, width, height)
return bitmap
}
private fun processOutputBuffer(outputBuffer: FloatBuffer, w: Int, h: Int, c: Int): Segmentation {
outputBuffer.rewind()
val inferenceData =
InferenceData(width = w, height = h, channels = c, buffer = outputBuffer)
val mask = generateMaskFromOutputBuffer(inferenceData)
val imageProperties =
ImageProperties
.builder()
.setWidth(inferenceData.width)
.setHeight(inferenceData.height)
.setColorSpaceType(ColorSpaceType.GRAYSCALE)
.build()
val maskImage = TensorImage()
maskImage.load(mask, imageProperties)
return Segmentation(maskImage)
}
private fun generateMaskFromOutputBuffer(inferenceData: InferenceData): ByteBuffer {
val width = inferenceData.width
val height = inferenceData.height
val mask = ByteBuffer.allocateDirect(width * height)
for (i in 0 until height) {
for (j in 0 until width) {
var maxIndex = 0
var maxValue = inferenceData.buffer[i * width + j]
for (c in 1 until inferenceData.channels) {
val value = inferenceData.buffer[c * height * width + i * width + j]
if (value > maxValue) {
maxValue = value
maxIndex = c
}
}
mask.put(i * width + j, maxIndex.toByte())
}
}
return mask
}
data class Segmentation(val mask: TensorImage) {
fun toBinaryMask(): Bitmap {
val width = mask.width
val height = mask.height
val pixels = IntArray(width * height)
for (i in 0 until height) {
for (j in 0 until width) {
val index = i * width + j
val classId = mask.buffer[index].toInt() and 0xFF // Unsigned byte
pixels[index] = if (classId == 0) Color.BLACK else Color.WHITE
}
}
return createBitmap(pixels, width, height, Bitmap.Config.ARGB_8888)
}
data class Segmentation(val mask: Bitmap) {
fun toBinaryMask(): Bitmap = mask
}
data class SegmentationResult(
val segmentation: Segmentation,
val inferenceTime: Long
)
data class InferenceData(
val width: Int,
val height: Int,
val channels: Int,
val buffer: FloatBuffer,
)
}

View File

@@ -60,6 +60,8 @@ class MainViewModel(
private var _pageToValidate = MutableStateFlow<Bitmap?>(null)
val pageToValidate: StateFlow<Bitmap?> = _pageToValidate.asStateFlow()
var liveAnalysisEnabled = true
init {
viewModelScope.launch {
imageSegmentationService.initialize()
@@ -80,6 +82,11 @@ class MainViewModel(
}
fun segment(imageProxy: ImageProxy) {
if (!liveAnalysisEnabled) {
imageProxy.close()
return
}
viewModelScope.launch {
imageSegmentationService.runSegmentationAndEmit(
imageProxy.toBitmap(),

View File

@@ -131,6 +131,7 @@ fun CameraScreen(
pageCount = viewModel.pageCount(),
liveAnalysisState,
onCapture = {
viewModel.liveAnalysisEnabled = false
showPageDialog.value = true
isProcessing.value = true
captureController.takePicture(
@@ -138,6 +139,7 @@ fun CameraScreen(
if (imageProxy != null) {
viewModel.processCapturedImageThen(imageProxy) {
isProcessing.value = false
viewModel.liveAnalysisEnabled = true
}
} else {
Log.e("MyScan", "Error during image capture")