Switch to a pytorch-based model detecting documents
This commit is contained in:
BIN
app/src/main/assets/IMG_20250329_132414_900.jpg
Normal file
BIN
app/src/main/assets/IMG_20250329_132414_900.jpg
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 80 KiB |
@@ -4,8 +4,10 @@ import android.content.Context
|
||||
import android.graphics.Bitmap
|
||||
import android.graphics.Bitmap.createBitmap
|
||||
import android.graphics.Color.argb
|
||||
import android.graphics.Matrix
|
||||
import android.os.SystemClock
|
||||
import android.util.Log
|
||||
import androidx.core.graphics.scale
|
||||
import kotlinx.coroutines.Dispatchers
|
||||
import kotlinx.coroutines.flow.MutableSharedFlow
|
||||
import kotlinx.coroutines.flow.MutableStateFlow
|
||||
@@ -16,13 +18,9 @@ import kotlinx.coroutines.isActive
|
||||
import kotlinx.coroutines.withContext
|
||||
import org.tensorflow.lite.Interpreter
|
||||
import org.tensorflow.lite.support.common.FileUtil
|
||||
import org.tensorflow.lite.support.common.ops.NormalizeOp
|
||||
import org.tensorflow.lite.support.image.ColorSpaceType
|
||||
import org.tensorflow.lite.support.image.ImageProcessor
|
||||
import org.tensorflow.lite.support.image.ImageProperties
|
||||
import org.tensorflow.lite.support.image.TensorImage
|
||||
import org.tensorflow.lite.support.image.ops.ResizeOp
|
||||
import org.tensorflow.lite.support.image.ops.Rot90Op
|
||||
import java.nio.ByteBuffer
|
||||
import java.nio.FloatBuffer
|
||||
|
||||
@@ -44,7 +42,7 @@ class ImageSegmentationService(private val context: Context) {
|
||||
|
||||
suspend fun initialize() {
|
||||
interpreter = try {
|
||||
val litertBuffer = FileUtil.loadMappedFile(context, "deeplab_v3.tflite")
|
||||
val litertBuffer = FileUtil.loadMappedFile(context, "mydeeplabv3.tflite")
|
||||
Log.i(TAG, "Loaded LiteRT model")
|
||||
val options = Interpreter.Options().apply {
|
||||
numThreads = 2
|
||||
@@ -63,21 +61,23 @@ class ImageSegmentationService(private val context: Context) {
|
||||
if (interpreter == null) return@withContext
|
||||
val startTime = SystemClock.uptimeMillis()
|
||||
|
||||
val rotation = -rotationDegrees / 90
|
||||
val (_, h, w, _) = interpreter?.getOutputTensor(0)?.shape() ?: return@withContext
|
||||
val imageProcessor =
|
||||
ImageProcessor
|
||||
.Builder()
|
||||
.add(ResizeOp(h, w, ResizeOp.ResizeMethod.BILINEAR))
|
||||
.add(Rot90Op(rotation))
|
||||
.add(NormalizeOp(127.5f, 127.5f))
|
||||
.build()
|
||||
val (_, _, h, w) = interpreter?.getInputTensor(0)?.shape() ?: return@withContext
|
||||
val dataType = interpreter?.getInputTensor(0)?.dataType()
|
||||
Log.i(TAG, "segment, input shape: ${interpreter!!.getInputTensor(0).shape().asList()} data type=${dataType}")
|
||||
|
||||
// Preprocess manually into CHW float buffer
|
||||
val inputBuffer = bitmapToCHWFloatBuffer(bitmap, width = w, height = h, rotationDegrees)
|
||||
|
||||
val (_, cOut, hOut, wOut) = interpreter!!.getOutputTensor(0).shape()
|
||||
val outputBuffer = FloatBuffer.allocate(cOut * hOut * wOut)
|
||||
|
||||
// Run inference
|
||||
outputBuffer.rewind()
|
||||
interpreter?.run(inputBuffer, outputBuffer)
|
||||
|
||||
// Preprocess the image and convert it into a TensorImage for segmentation.
|
||||
val tensorImage = imageProcessor.process(TensorImage.fromBitmap(bitmap))
|
||||
val segmentResult = segment(tensorImage)
|
||||
val inferenceTime = SystemClock.uptimeMillis() - startTime
|
||||
if (isActive) {
|
||||
val segmentResult = processOutputBuffer(outputBuffer, wOut, hOut, cOut)
|
||||
_segmentation.value = SegmentationResult(segmentResult, inferenceTime)
|
||||
}
|
||||
}
|
||||
@@ -87,14 +87,46 @@ class ImageSegmentationService(private val context: Context) {
|
||||
}
|
||||
}
|
||||
|
||||
fun bitmapToCHWFloatBuffer(bitmap: Bitmap, width: Int, height: Int, rotationDegrees: Int): FloatBuffer {
|
||||
val rotatedBitmap = if (rotationDegrees != 0) {
|
||||
val matrix = Matrix().apply { postRotate(rotationDegrees.toFloat()) }
|
||||
createBitmap(bitmap, 0, 0, bitmap.width, bitmap.height, matrix, true)
|
||||
} else {
|
||||
bitmap
|
||||
}
|
||||
|
||||
private fun segment(tensorImage: TensorImage): Segmentation {
|
||||
val (_, h, w, c) = interpreter!!.getOutputTensor(0).shape()
|
||||
val outputBuffer = FloatBuffer.allocate(h * w * c)
|
||||
val resized = rotatedBitmap.scale(width, height)
|
||||
val buffer = FloatBuffer.allocate(1 * 3 * height * width)
|
||||
buffer.rewind()
|
||||
|
||||
outputBuffer.rewind()
|
||||
interpreter?.run(tensorImage.tensorBuffer.buffer, outputBuffer)
|
||||
val mean = floatArrayOf(0.4611f, 0.4359f, 0.3905f)
|
||||
val std = floatArrayOf(0.2193f, 0.2150f, 0.2109f)
|
||||
|
||||
val pixels = IntArray(width * height)
|
||||
resized.getPixels(pixels, 0, width, 0, 0, width, height)
|
||||
|
||||
// Fill buffer in CHW order
|
||||
for (c in 0..2) {
|
||||
for (i in 0 until height) {
|
||||
for (j in 0 until width) {
|
||||
val pixel = pixels[i * width + j]
|
||||
val value = when (c) {
|
||||
0 -> (pixel shr 16 and 0xFF) // R
|
||||
1 -> (pixel shr 8 and 0xFF) // G
|
||||
2 -> (pixel and 0xFF) // B
|
||||
else -> 0
|
||||
}
|
||||
val normalized = (value / 255f - mean[c]) / std[c]
|
||||
buffer.put(normalized)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
buffer.rewind()
|
||||
return buffer
|
||||
}
|
||||
|
||||
private fun processOutputBuffer(outputBuffer: FloatBuffer, w: Int, h: Int, c: Int): Segmentation {
|
||||
outputBuffer.rewind()
|
||||
val inferenceData =
|
||||
InferenceData(width = w, height = h, channels = c, buffer = outputBuffer)
|
||||
@@ -116,15 +148,15 @@ class ImageSegmentationService(private val context: Context) {
|
||||
val mask = ByteBuffer.allocateDirect(inferenceData.width * inferenceData.height)
|
||||
for (i in 0 until inferenceData.height) {
|
||||
for (j in 0 until inferenceData.width) {
|
||||
val offset = inferenceData.channels * (i * inferenceData.width + j)
|
||||
|
||||
var maxIndex = 0
|
||||
var maxValue = inferenceData.buffer.get(offset)
|
||||
var maxValue = inferenceData.buffer.get(i * inferenceData.width + j)
|
||||
|
||||
for (index in 1 until inferenceData.channels) {
|
||||
if (inferenceData.buffer.get(offset + index) > maxValue) {
|
||||
maxValue = inferenceData.buffer.get(offset + index)
|
||||
maxIndex = index
|
||||
for (c in 1 until inferenceData.channels) {
|
||||
val value = inferenceData.buffer.get(
|
||||
c * inferenceData.height * inferenceData.width + i * inferenceData.width + j)
|
||||
if (value > maxValue) {
|
||||
maxValue = value
|
||||
maxIndex = c
|
||||
}
|
||||
}
|
||||
|
||||
@@ -164,4 +196,4 @@ class ImageSegmentationService(private val context: Context) {
|
||||
val channels: Int,
|
||||
val buffer: FloatBuffer,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package org.mydomain.myscan
|
||||
|
||||
import android.content.Context
|
||||
import android.graphics.Bitmap
|
||||
import android.util.Log
|
||||
import androidx.camera.core.ImageProxy
|
||||
import androidx.lifecycle.ViewModel
|
||||
@@ -34,7 +35,7 @@ class MainViewModel(private val imageSegmentationService: ImageSegmentationServi
|
||||
.filterNotNull()
|
||||
.map {
|
||||
UiState(
|
||||
"Found ${numberOfObjectsDetected(it.segmentation)} objects!",
|
||||
"Inference done",
|
||||
it.inferenceTime,
|
||||
it.segmentation.toBitmap())
|
||||
}
|
||||
@@ -45,18 +46,7 @@ class MainViewModel(private val imageSegmentationService: ImageSegmentationServi
|
||||
}
|
||||
}
|
||||
|
||||
fun numberOfObjectsDetected(segmentation: ImageSegmentationService.Segmentation) : Int {
|
||||
val tensor = segmentation.mask;
|
||||
val buffer = tensor.buffer
|
||||
val uniqueValues = HashSet<Int>()
|
||||
for (i in 0..tensor.width * tensor.height - 1) {
|
||||
uniqueValues.add(buffer[i].toInt())
|
||||
}
|
||||
return uniqueValues.size - 1;
|
||||
}
|
||||
|
||||
fun segment(imageProxy: ImageProxy) {
|
||||
Log.d("MyScan", "MainViewModel.Calling segment")
|
||||
viewModelScope.launch {
|
||||
imageSegmentationService.runSegmentation(
|
||||
imageProxy.toBitmap(),
|
||||
|
||||
Reference in New Issue
Block a user