New segmentation model

This commit is contained in:
Pierre-Yves Nicolas
2025-06-19 15:00:20 +02:00
committed by pynicolas
parent e3ada11c8c
commit 0c5a219783
3 changed files with 52 additions and 111 deletions

View File

@@ -18,23 +18,24 @@ import android.content.Context
import android.graphics.Bitmap import android.graphics.Bitmap
import android.graphics.Bitmap.createBitmap import android.graphics.Bitmap.createBitmap
import android.graphics.Color import android.graphics.Color
import android.graphics.Matrix
import android.os.SystemClock import android.os.SystemClock
import android.util.Log import android.util.Log
import androidx.core.graphics.scale
import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.flow.MutableStateFlow import kotlinx.coroutines.flow.MutableStateFlow
import kotlinx.coroutines.flow.StateFlow import kotlinx.coroutines.flow.StateFlow
import kotlinx.coroutines.flow.asStateFlow import kotlinx.coroutines.flow.asStateFlow
import kotlinx.coroutines.isActive import kotlinx.coroutines.isActive
import kotlinx.coroutines.withContext import kotlinx.coroutines.withContext
import org.tensorflow.lite.DataType
import org.tensorflow.lite.Interpreter import org.tensorflow.lite.Interpreter
import org.tensorflow.lite.support.common.FileUtil import org.tensorflow.lite.support.common.FileUtil
import org.tensorflow.lite.support.image.ColorSpaceType import org.tensorflow.lite.support.common.ops.NormalizeOp
import org.tensorflow.lite.support.image.ImageProperties import org.tensorflow.lite.support.image.ImageProcessor
import org.tensorflow.lite.support.image.TensorImage import org.tensorflow.lite.support.image.TensorImage
import org.tensorflow.lite.support.image.ops.ResizeOp
import org.tensorflow.lite.support.image.ops.Rot90Op
import java.nio.ByteBuffer import java.nio.ByteBuffer
import java.nio.FloatBuffer import java.nio.ByteOrder
class ImageSegmentationService(private val context: Context) { class ImageSegmentationService(private val context: Context) {
@@ -49,7 +50,7 @@ class ImageSegmentationService(private val context: Context) {
fun initialize() { fun initialize() {
interpreter = try { interpreter = try {
val litertBuffer = FileUtil.loadMappedFile(context, "mydeeplabv3.tflite") val litertBuffer = FileUtil.loadMappedFile(context, "timm_efficientnet_lite0_quantized.tflite")
Log.i(TAG, "Loaded LiteRT model") Log.i(TAG, "Loaded LiteRT model")
val options = Interpreter.Options().apply { val options = Interpreter.Options().apply {
numThreads = 2 numThreads = 2
@@ -64,19 +65,21 @@ class ImageSegmentationService(private val context: Context) {
private fun runSegmentation(interpreter: Interpreter, bitmap: Bitmap, rotationDegrees: Int): SegmentationResult { private fun runSegmentation(interpreter: Interpreter, bitmap: Bitmap, rotationDegrees: Int): SegmentationResult {
val startTime = SystemClock.uptimeMillis() val startTime = SystemClock.uptimeMillis()
val (_, _, h, w) = interpreter.getInputTensor(0).shape() val rotation = -rotationDegrees / 90
// Preprocess manually into CHW float buffer val (_, h, w, _) = interpreter.getOutputTensor(0).shape()
val inputBuffer = bitmapToCHWFloatBuffer(bitmap, width = w, height = h, rotationDegrees) val imageProcessor =
ImageProcessor
val (_, cOut, hOut, wOut) = interpreter.getOutputTensor(0).shape() .Builder()
val outputBuffer = FloatBuffer.allocate(cOut * hOut * wOut) .add(ResizeOp(h, w, ResizeOp.ResizeMethod.BILINEAR))
.add(Rot90Op(rotation))
// Run inference .add(NormalizeOp(127.5f, 127.5f)) // TODO check if it's correct
outputBuffer.rewind() .build()
interpreter.run(inputBuffer, outputBuffer) val tensorImage = TensorImage(DataType.FLOAT32)
tensorImage.load(bitmap)
val processedImage = imageProcessor.process(tensorImage)
val segmentResult = segment(interpreter, processedImage)
val inferenceTime = SystemClock.uptimeMillis() - startTime val inferenceTime = SystemClock.uptimeMillis() - startTime
val segmentResult = processOutputBuffer(outputBuffer, wOut, hOut, cOut)
return SegmentationResult(segmentResult, inferenceTime) return SegmentationResult(segmentResult, inferenceTime)
} }
@@ -101,111 +104,40 @@ class ImageSegmentationService(private val context: Context) {
} }
} }
fun bitmapToCHWFloatBuffer(bitmap: Bitmap, width: Int, height: Int, rotationDegrees: Int): FloatBuffer { private fun segment(interpreter: Interpreter, tensorImage: TensorImage): Segmentation {
val rotatedBitmap = if (rotationDegrees != 0) { val (_, h, w, _) = interpreter.getOutputTensor(0).shape()
val matrix = Matrix().apply { postRotate(rotationDegrees.toFloat()) } val outputBuffer = ByteBuffer.allocateDirect(4 * h * w)
createBitmap(bitmap, 0, 0, bitmap.width, bitmap.height, matrix, true) outputBuffer.order(ByteOrder.nativeOrder())
} else {
bitmap
}
val resized = rotatedBitmap.scale(width, height)
val buffer = FloatBuffer.allocate(1 * 3 * height * width)
buffer.rewind()
val mean = floatArrayOf(0.4611f, 0.4359f, 0.3905f)
val std = floatArrayOf(0.2193f, 0.2150f, 0.2109f)
val pixels = IntArray(width * height)
resized.getPixels(pixels, 0, width, 0, 0, width, height)
// Fill buffer in CHW order
for (c in 0..2) {
for (i in 0 until height) {
for (j in 0 until width) {
val pixel = pixels[i * width + j]
val value = when (c) {
0 -> (pixel shr 16 and 0xFF) // R
1 -> (pixel shr 8 and 0xFF) // G
2 -> (pixel and 0xFF) // B
else -> 0
}
val normalized = (value / 255f - mean[c]) / std[c]
buffer.put(normalized)
}
}
}
buffer.rewind()
return buffer
}
private fun processOutputBuffer(outputBuffer: FloatBuffer, w: Int, h: Int, c: Int): Segmentation {
outputBuffer.rewind() outputBuffer.rewind()
val inferenceData = interpreter.run(tensorImage.tensorBuffer.buffer, outputBuffer)
InferenceData(width = w, height = h, channels = c, buffer = outputBuffer) outputBuffer.rewind()
val mask = generateMaskFromOutputBuffer(inferenceData) val mask = generateMaskFromOutputBuffer(outputBuffer, w, h)
return Segmentation(mask)
val imageProperties =
ImageProperties
.builder()
.setWidth(inferenceData.width)
.setHeight(inferenceData.height)
.setColorSpaceType(ColorSpaceType.GRAYSCALE)
.build()
val maskImage = TensorImage()
maskImage.load(mask, imageProperties)
return Segmentation(maskImage)
} }
private fun generateMaskFromOutputBuffer(inferenceData: InferenceData): ByteBuffer { private fun generateMaskFromOutputBuffer(outputBuffer: ByteBuffer, width: Int, height: Int): Bitmap {
val width = inferenceData.width outputBuffer.rewind()
val height = inferenceData.height val floatArray = FloatArray(width * height)
val mask = ByteBuffer.allocateDirect(width * height) outputBuffer.asFloatBuffer()[floatArray]
for (i in 0 until height) {
for (j in 0 until width) {
var maxIndex = 0
var maxValue = inferenceData.buffer[i * width + j]
for (c in 1 until inferenceData.channels) {
val value = inferenceData.buffer[c * height * width + i * width + j]
if (value > maxValue) {
maxValue = value
maxIndex = c
}
}
mask.put(i * width + j, maxIndex.toByte())
}
}
return mask
}
data class Segmentation(val mask: TensorImage) {
fun toBinaryMask(): Bitmap {
val width = mask.width
val height = mask.height
val pixels = IntArray(width * height) val pixels = IntArray(width * height)
for (i in 0 until height) { for (i in floatArray.indices) {
for (j in 0 until width) { val value = floatArray[i].coerceIn(0f, 1f)
val index = i * width + j val gray = (value * 255).toInt()
val classId = mask.buffer[index].toInt() and 0xFF // Unsigned byte pixels[i] = Color.rgb(gray, gray, gray)
pixels[index] = if (classId == 0) Color.BLACK else Color.WHITE
} }
val bitmap = createBitmap(width, height, Bitmap.Config.ARGB_8888)
bitmap.setPixels(pixels, 0, width, 0, 0, width, height)
return bitmap
} }
return createBitmap(pixels, width, height, Bitmap.Config.ARGB_8888)
} data class Segmentation(val mask: Bitmap) {
fun toBinaryMask(): Bitmap = mask
} }
data class SegmentationResult( data class SegmentationResult(
val segmentation: Segmentation, val segmentation: Segmentation,
val inferenceTime: Long val inferenceTime: Long
) )
data class InferenceData(
val width: Int,
val height: Int,
val channels: Int,
val buffer: FloatBuffer,
)
} }

View File

@@ -60,6 +60,8 @@ class MainViewModel(
private var _pageToValidate = MutableStateFlow<Bitmap?>(null) private var _pageToValidate = MutableStateFlow<Bitmap?>(null)
val pageToValidate: StateFlow<Bitmap?> = _pageToValidate.asStateFlow() val pageToValidate: StateFlow<Bitmap?> = _pageToValidate.asStateFlow()
var liveAnalysisEnabled = true
init { init {
viewModelScope.launch { viewModelScope.launch {
imageSegmentationService.initialize() imageSegmentationService.initialize()
@@ -80,6 +82,11 @@ class MainViewModel(
} }
fun segment(imageProxy: ImageProxy) { fun segment(imageProxy: ImageProxy) {
if (!liveAnalysisEnabled) {
imageProxy.close()
return
}
viewModelScope.launch { viewModelScope.launch {
imageSegmentationService.runSegmentationAndEmit( imageSegmentationService.runSegmentationAndEmit(
imageProxy.toBitmap(), imageProxy.toBitmap(),

View File

@@ -131,6 +131,7 @@ fun CameraScreen(
pageCount = viewModel.pageCount(), pageCount = viewModel.pageCount(),
liveAnalysisState, liveAnalysisState,
onCapture = { onCapture = {
viewModel.liveAnalysisEnabled = false
showPageDialog.value = true showPageDialog.value = true
isProcessing.value = true isProcessing.value = true
captureController.takePicture( captureController.takePicture(
@@ -138,6 +139,7 @@ fun CameraScreen(
if (imageProxy != null) { if (imageProxy != null) {
viewModel.processCapturedImageThen(imageProxy) { viewModel.processCapturedImageThen(imageProxy) {
isProcessing.value = false isProcessing.value = false
viewModel.liveAnalysisEnabled = true
} }
} else { } else {
Log.e("MyScan", "Error during image capture") Log.e("MyScan", "Error during image capture")