Switch to a pytorch-based model detecting documents

This commit is contained in:
Pierre-Yves Nicolas
2025-05-29 21:02:00 +02:00
parent 3457a85044
commit 0c3a666502
3 changed files with 64 additions and 42 deletions

Binary file not shown.

After

Width:  |  Height:  |  Size: 80 KiB

View File

@@ -4,8 +4,10 @@ import android.content.Context
import android.graphics.Bitmap import android.graphics.Bitmap
import android.graphics.Bitmap.createBitmap import android.graphics.Bitmap.createBitmap
import android.graphics.Color.argb import android.graphics.Color.argb
import android.graphics.Matrix
import android.os.SystemClock import android.os.SystemClock
import android.util.Log import android.util.Log
import androidx.core.graphics.scale
import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.flow.MutableSharedFlow import kotlinx.coroutines.flow.MutableSharedFlow
import kotlinx.coroutines.flow.MutableStateFlow import kotlinx.coroutines.flow.MutableStateFlow
@@ -16,13 +18,9 @@ import kotlinx.coroutines.isActive
import kotlinx.coroutines.withContext import kotlinx.coroutines.withContext
import org.tensorflow.lite.Interpreter import org.tensorflow.lite.Interpreter
import org.tensorflow.lite.support.common.FileUtil import org.tensorflow.lite.support.common.FileUtil
import org.tensorflow.lite.support.common.ops.NormalizeOp
import org.tensorflow.lite.support.image.ColorSpaceType import org.tensorflow.lite.support.image.ColorSpaceType
import org.tensorflow.lite.support.image.ImageProcessor
import org.tensorflow.lite.support.image.ImageProperties import org.tensorflow.lite.support.image.ImageProperties
import org.tensorflow.lite.support.image.TensorImage import org.tensorflow.lite.support.image.TensorImage
import org.tensorflow.lite.support.image.ops.ResizeOp
import org.tensorflow.lite.support.image.ops.Rot90Op
import java.nio.ByteBuffer import java.nio.ByteBuffer
import java.nio.FloatBuffer import java.nio.FloatBuffer
@@ -44,7 +42,7 @@ class ImageSegmentationService(private val context: Context) {
suspend fun initialize() { suspend fun initialize() {
interpreter = try { interpreter = try {
val litertBuffer = FileUtil.loadMappedFile(context, "deeplab_v3.tflite") val litertBuffer = FileUtil.loadMappedFile(context, "mydeeplabv3.tflite")
Log.i(TAG, "Loaded LiteRT model") Log.i(TAG, "Loaded LiteRT model")
val options = Interpreter.Options().apply { val options = Interpreter.Options().apply {
numThreads = 2 numThreads = 2
@@ -63,21 +61,23 @@ class ImageSegmentationService(private val context: Context) {
if (interpreter == null) return@withContext if (interpreter == null) return@withContext
val startTime = SystemClock.uptimeMillis() val startTime = SystemClock.uptimeMillis()
val rotation = -rotationDegrees / 90 val (_, _, h, w) = interpreter?.getInputTensor(0)?.shape() ?: return@withContext
val (_, h, w, _) = interpreter?.getOutputTensor(0)?.shape() ?: return@withContext val dataType = interpreter?.getInputTensor(0)?.dataType()
val imageProcessor = Log.i(TAG, "segment, input shape: ${interpreter!!.getInputTensor(0).shape().asList()} data type=${dataType}")
ImageProcessor
.Builder() // Preprocess manually into CHW float buffer
.add(ResizeOp(h, w, ResizeOp.ResizeMethod.BILINEAR)) val inputBuffer = bitmapToCHWFloatBuffer(bitmap, width = w, height = h, rotationDegrees)
.add(Rot90Op(rotation))
.add(NormalizeOp(127.5f, 127.5f)) val (_, cOut, hOut, wOut) = interpreter!!.getOutputTensor(0).shape()
.build() val outputBuffer = FloatBuffer.allocate(cOut * hOut * wOut)
// Run inference
outputBuffer.rewind()
interpreter?.run(inputBuffer, outputBuffer)
// Preprocess the image and convert it into a TensorImage for segmentation.
val tensorImage = imageProcessor.process(TensorImage.fromBitmap(bitmap))
val segmentResult = segment(tensorImage)
val inferenceTime = SystemClock.uptimeMillis() - startTime val inferenceTime = SystemClock.uptimeMillis() - startTime
if (isActive) { if (isActive) {
val segmentResult = processOutputBuffer(outputBuffer, wOut, hOut, cOut)
_segmentation.value = SegmentationResult(segmentResult, inferenceTime) _segmentation.value = SegmentationResult(segmentResult, inferenceTime)
} }
} }
@@ -87,14 +87,46 @@ class ImageSegmentationService(private val context: Context) {
} }
} }
/**
 * Converts [bitmap] into a normalized float buffer laid out channel by channel:
 * all red values first, then all green, then all blue, each plane in row-major order.
 *
 * The bitmap is rotated by [rotationDegrees] when that value is non-zero, scaled to
 * [width] x [height], and every 8-bit channel value is mapped to
 * `(value / 255 - mean[c]) / std[c]` using the per-channel constants defined in the body.
 *
 * @param bitmap source image.
 * @param width target width in pixels after scaling.
 * @param height target height in pixels after scaling.
 * @param rotationDegrees rotation applied before scaling; `0` leaves the bitmap untouched.
 * @return a [FloatBuffer] holding `3 * height * width` floats, rewound to position 0.
 */
fun bitmapToCHWFloatBuffer(bitmap: Bitmap, width: Int, height: Int, rotationDegrees: Int): FloatBuffer {
// Only build a rotated copy when a rotation is actually requested.
val rotatedBitmap = if (rotationDegrees != 0) {
val matrix = Matrix().apply { postRotate(rotationDegrees.toFloat()) }
createBitmap(bitmap, 0, 0, bitmap.width, bitmap.height, matrix, true)
} else {
bitmap
}
// Scale the (possibly rotated) bitmap to the requested width x height.
private fun segment(tensorImage: TensorImage): Segmentation { val resized = rotatedBitmap.scale(width, height)
// One float per channel per pixel: 1 * 3 * height * width.
val (_, h, w, c) = interpreter!!.getOutputTensor(0).shape() val buffer = FloatBuffer.allocate(1 * 3 * height * width)
val outputBuffer = FloatBuffer.allocate(h * w * c) buffer.rewind()
// Per-channel mean and standard deviation used by the normalization below
// (presumably the statistics the model was trained with — TODO confirm).
outputBuffer.rewind() val mean = floatArrayOf(0.4611f, 0.4359f, 0.3905f)
interpreter?.run(tensorImage.tensorBuffer.buffer, outputBuffer) val std = floatArrayOf(0.2193f, 0.2150f, 0.2109f)
// Read all pixels of the scaled bitmap as packed ints, one row after another.
val pixels = IntArray(width * height)
resized.getPixels(pixels, 0, width, 0, 0, width, height)
// Fill buffer in CHW order
// The channel loop is outermost, so each colour plane is written contiguously.
for (c in 0..2) {
for (i in 0 until height) {
for (j in 0 until width) {
val pixel = pixels[i * width + j]
// Extract the 8-bit component for the current channel from the packed pixel.
val value = when (c) {
0 -> (pixel shr 16 and 0xFF) // R
1 -> (pixel shr 8 and 0xFF) // G
2 -> (pixel and 0xFF) // B
else -> 0
}
// Scale to [0, 1], then standardize with this channel's mean and std.
val normalized = (value / 255f - mean[c]) / std[c]
buffer.put(normalized)
}
}
}
// Reset the position so the caller reads from the first element.
buffer.rewind()
return buffer
}
private fun processOutputBuffer(outputBuffer: FloatBuffer, w: Int, h: Int, c: Int): Segmentation {
outputBuffer.rewind() outputBuffer.rewind()
val inferenceData = val inferenceData =
InferenceData(width = w, height = h, channels = c, buffer = outputBuffer) InferenceData(width = w, height = h, channels = c, buffer = outputBuffer)
@@ -116,15 +148,15 @@ class ImageSegmentationService(private val context: Context) {
val mask = ByteBuffer.allocateDirect(inferenceData.width * inferenceData.height) val mask = ByteBuffer.allocateDirect(inferenceData.width * inferenceData.height)
for (i in 0 until inferenceData.height) { for (i in 0 until inferenceData.height) {
for (j in 0 until inferenceData.width) { for (j in 0 until inferenceData.width) {
val offset = inferenceData.channels * (i * inferenceData.width + j)
var maxIndex = 0 var maxIndex = 0
var maxValue = inferenceData.buffer.get(offset) var maxValue = inferenceData.buffer.get(i * inferenceData.width + j)
for (index in 1 until inferenceData.channels) { for (c in 1 until inferenceData.channels) {
if (inferenceData.buffer.get(offset + index) > maxValue) { val value = inferenceData.buffer.get(
maxValue = inferenceData.buffer.get(offset + index) c * inferenceData.height * inferenceData.width + i * inferenceData.width + j)
maxIndex = index if (value > maxValue) {
maxValue = value
maxIndex = c
} }
} }
@@ -164,4 +196,4 @@ class ImageSegmentationService(private val context: Context) {
val channels: Int, val channels: Int,
val buffer: FloatBuffer, val buffer: FloatBuffer,
) )
} }

View File

@@ -1,6 +1,7 @@
package org.mydomain.myscan package org.mydomain.myscan
import android.content.Context import android.content.Context
import android.graphics.Bitmap
import android.util.Log import android.util.Log
import androidx.camera.core.ImageProxy import androidx.camera.core.ImageProxy
import androidx.lifecycle.ViewModel import androidx.lifecycle.ViewModel
@@ -34,7 +35,7 @@ class MainViewModel(private val imageSegmentationService: ImageSegmentationServi
.filterNotNull() .filterNotNull()
.map { .map {
UiState( UiState(
"Found ${numberOfObjectsDetected(it.segmentation)} objects!", "Inference done",
it.inferenceTime, it.inferenceTime,
it.segmentation.toBitmap()) it.segmentation.toBitmap())
} }
@@ -45,18 +46,7 @@ class MainViewModel(private val imageSegmentationService: ImageSegmentationServi
} }
} }
/**
 * Counts the distinct label values found in the segmentation mask, minus one.
 *
 * Reads the first `width * height` entries of the mask buffer and collects each one
 * as an [Int]. One is subtracted from the number of distinct values (presumably to
 * discount a background label — TODO confirm against the model's label map).
 *
 * @param segmentation segmentation result whose mask is inspected.
 * @return the number of distinct labels minus one, never less than zero. An empty
 *   mask yields `0` rather than `-1`.
 */
fun numberOfObjectsDetected(segmentation: ImageSegmentationService.Segmentation): Int {
    val mask = segmentation.mask
    val pixelCount = mask.width * mask.height
    val labels = HashSet<Int>()
    for (i in 0 until pixelCount) {
        labels.add(mask.buffer[i].toInt())
    }
    // With no pixels there are no labels at all; report zero instead of -1.
    return (labels.size - 1).coerceAtLeast(0)
}
fun segment(imageProxy: ImageProxy) { fun segment(imageProxy: ImageProxy) {
Log.d("MyScan", "MainViewModel.Calling segment")
viewModelScope.launch { viewModelScope.launch {
imageSegmentationService.runSegmentation( imageSegmentationService.runSegmentation(
imageProxy.toBitmap(), imageProxy.toBitmap(),