Switch to a pytorch-based model detecting documents

This commit is contained in:
Pierre-Yves Nicolas
2025-05-29 21:02:00 +02:00
parent 3457a85044
commit 0c3a666502
3 changed files with 64 additions and 42 deletions

Binary file not shown.

After

Width:  |  Height:  |  Size: 80 KiB

View File

@@ -4,8 +4,10 @@ import android.content.Context
import android.graphics.Bitmap
import android.graphics.Bitmap.createBitmap
import android.graphics.Color.argb
import android.graphics.Matrix
import android.os.SystemClock
import android.util.Log
import androidx.core.graphics.scale
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.flow.MutableSharedFlow
import kotlinx.coroutines.flow.MutableStateFlow
@@ -16,13 +18,9 @@ import kotlinx.coroutines.isActive
import kotlinx.coroutines.withContext
import org.tensorflow.lite.Interpreter
import org.tensorflow.lite.support.common.FileUtil
import org.tensorflow.lite.support.common.ops.NormalizeOp
import org.tensorflow.lite.support.image.ColorSpaceType
import org.tensorflow.lite.support.image.ImageProcessor
import org.tensorflow.lite.support.image.ImageProperties
import org.tensorflow.lite.support.image.TensorImage
import org.tensorflow.lite.support.image.ops.ResizeOp
import org.tensorflow.lite.support.image.ops.Rot90Op
import java.nio.ByteBuffer
import java.nio.FloatBuffer
@@ -44,7 +42,7 @@ class ImageSegmentationService(private val context: Context) {
suspend fun initialize() {
interpreter = try {
val litertBuffer = FileUtil.loadMappedFile(context, "deeplab_v3.tflite")
val litertBuffer = FileUtil.loadMappedFile(context, "mydeeplabv3.tflite")
Log.i(TAG, "Loaded LiteRT model")
val options = Interpreter.Options().apply {
numThreads = 2
@@ -63,21 +61,23 @@ class ImageSegmentationService(private val context: Context) {
if (interpreter == null) return@withContext
val startTime = SystemClock.uptimeMillis()
val rotation = -rotationDegrees / 90
val (_, h, w, _) = interpreter?.getOutputTensor(0)?.shape() ?: return@withContext
val imageProcessor =
ImageProcessor
.Builder()
.add(ResizeOp(h, w, ResizeOp.ResizeMethod.BILINEAR))
.add(Rot90Op(rotation))
.add(NormalizeOp(127.5f, 127.5f))
.build()
val (_, _, h, w) = interpreter?.getInputTensor(0)?.shape() ?: return@withContext
val dataType = interpreter?.getInputTensor(0)?.dataType()
Log.i(TAG, "segment, input shape: ${interpreter!!.getInputTensor(0).shape().asList()} data type=${dataType}")
// Preprocess manually into CHW float buffer
val inputBuffer = bitmapToCHWFloatBuffer(bitmap, width = w, height = h, rotationDegrees)
val (_, cOut, hOut, wOut) = interpreter!!.getOutputTensor(0).shape()
val outputBuffer = FloatBuffer.allocate(cOut * hOut * wOut)
// Run inference
outputBuffer.rewind()
interpreter?.run(inputBuffer, outputBuffer)
// Preprocess the image and convert it into a TensorImage for segmentation.
val tensorImage = imageProcessor.process(TensorImage.fromBitmap(bitmap))
val segmentResult = segment(tensorImage)
val inferenceTime = SystemClock.uptimeMillis() - startTime
if (isActive) {
val segmentResult = processOutputBuffer(outputBuffer, wOut, hOut, cOut)
_segmentation.value = SegmentationResult(segmentResult, inferenceTime)
}
}
@@ -87,14 +87,46 @@ class ImageSegmentationService(private val context: Context) {
}
}
fun bitmapToCHWFloatBuffer(bitmap: Bitmap, width: Int, height: Int, rotationDegrees: Int): FloatBuffer {
val rotatedBitmap = if (rotationDegrees != 0) {
val matrix = Matrix().apply { postRotate(rotationDegrees.toFloat()) }
createBitmap(bitmap, 0, 0, bitmap.width, bitmap.height, matrix, true)
} else {
bitmap
}
private fun segment(tensorImage: TensorImage): Segmentation {
val (_, h, w, c) = interpreter!!.getOutputTensor(0).shape()
val outputBuffer = FloatBuffer.allocate(h * w * c)
val resized = rotatedBitmap.scale(width, height)
val buffer = FloatBuffer.allocate(1 * 3 * height * width)
buffer.rewind()
outputBuffer.rewind()
interpreter?.run(tensorImage.tensorBuffer.buffer, outputBuffer)
val mean = floatArrayOf(0.4611f, 0.4359f, 0.3905f)
val std = floatArrayOf(0.2193f, 0.2150f, 0.2109f)
val pixels = IntArray(width * height)
resized.getPixels(pixels, 0, width, 0, 0, width, height)
// Fill buffer in CHW order
for (c in 0..2) {
for (i in 0 until height) {
for (j in 0 until width) {
val pixel = pixels[i * width + j]
val value = when (c) {
0 -> (pixel shr 16 and 0xFF) // R
1 -> (pixel shr 8 and 0xFF) // G
2 -> (pixel and 0xFF) // B
else -> 0
}
val normalized = (value / 255f - mean[c]) / std[c]
buffer.put(normalized)
}
}
}
buffer.rewind()
return buffer
}
private fun processOutputBuffer(outputBuffer: FloatBuffer, w: Int, h: Int, c: Int): Segmentation {
outputBuffer.rewind()
val inferenceData =
InferenceData(width = w, height = h, channels = c, buffer = outputBuffer)
@@ -116,15 +148,15 @@ class ImageSegmentationService(private val context: Context) {
val mask = ByteBuffer.allocateDirect(inferenceData.width * inferenceData.height)
for (i in 0 until inferenceData.height) {
for (j in 0 until inferenceData.width) {
val offset = inferenceData.channels * (i * inferenceData.width + j)
var maxIndex = 0
var maxValue = inferenceData.buffer.get(offset)
var maxValue = inferenceData.buffer.get(i * inferenceData.width + j)
for (index in 1 until inferenceData.channels) {
if (inferenceData.buffer.get(offset + index) > maxValue) {
maxValue = inferenceData.buffer.get(offset + index)
maxIndex = index
for (c in 1 until inferenceData.channels) {
val value = inferenceData.buffer.get(
c * inferenceData.height * inferenceData.width + i * inferenceData.width + j)
if (value > maxValue) {
maxValue = value
maxIndex = c
}
}

View File

@@ -1,6 +1,7 @@
package org.mydomain.myscan
import android.content.Context
import android.graphics.Bitmap
import android.util.Log
import androidx.camera.core.ImageProxy
import androidx.lifecycle.ViewModel
@@ -34,7 +35,7 @@ class MainViewModel(private val imageSegmentationService: ImageSegmentationServi
.filterNotNull()
.map {
UiState(
"Found ${numberOfObjectsDetected(it.segmentation)} objects!",
"Inference done",
it.inferenceTime,
it.segmentation.toBitmap())
}
@@ -45,18 +46,7 @@ class MainViewModel(private val imageSegmentationService: ImageSegmentationServi
}
}
fun numberOfObjectsDetected(segmentation: ImageSegmentationService.Segmentation) : Int {
val tensor = segmentation.mask;
val buffer = tensor.buffer
val uniqueValues = HashSet<Int>()
for (i in 0..tensor.width * tensor.height - 1) {
uniqueValues.add(buffer[i].toInt())
}
return uniqueValues.size - 1;
}
fun segment(imageProxy: ImageProxy) {
Log.d("MyScan", "MainViewModel.Calling segment")
viewModelScope.launch {
imageSegmentationService.runSegmentation(
imageProxy.toBitmap(),