Switch to a pytorch-based model detecting documents
This commit is contained in:
BIN
app/src/main/assets/IMG_20250329_132414_900.jpg
Normal file
BIN
app/src/main/assets/IMG_20250329_132414_900.jpg
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 80 KiB |
@@ -4,8 +4,10 @@ import android.content.Context
|
|||||||
import android.graphics.Bitmap
|
import android.graphics.Bitmap
|
||||||
import android.graphics.Bitmap.createBitmap
|
import android.graphics.Bitmap.createBitmap
|
||||||
import android.graphics.Color.argb
|
import android.graphics.Color.argb
|
||||||
|
import android.graphics.Matrix
|
||||||
import android.os.SystemClock
|
import android.os.SystemClock
|
||||||
import android.util.Log
|
import android.util.Log
|
||||||
|
import androidx.core.graphics.scale
|
||||||
import kotlinx.coroutines.Dispatchers
|
import kotlinx.coroutines.Dispatchers
|
||||||
import kotlinx.coroutines.flow.MutableSharedFlow
|
import kotlinx.coroutines.flow.MutableSharedFlow
|
||||||
import kotlinx.coroutines.flow.MutableStateFlow
|
import kotlinx.coroutines.flow.MutableStateFlow
|
||||||
@@ -16,13 +18,9 @@ import kotlinx.coroutines.isActive
|
|||||||
import kotlinx.coroutines.withContext
|
import kotlinx.coroutines.withContext
|
||||||
import org.tensorflow.lite.Interpreter
|
import org.tensorflow.lite.Interpreter
|
||||||
import org.tensorflow.lite.support.common.FileUtil
|
import org.tensorflow.lite.support.common.FileUtil
|
||||||
import org.tensorflow.lite.support.common.ops.NormalizeOp
|
|
||||||
import org.tensorflow.lite.support.image.ColorSpaceType
|
import org.tensorflow.lite.support.image.ColorSpaceType
|
||||||
import org.tensorflow.lite.support.image.ImageProcessor
|
|
||||||
import org.tensorflow.lite.support.image.ImageProperties
|
import org.tensorflow.lite.support.image.ImageProperties
|
||||||
import org.tensorflow.lite.support.image.TensorImage
|
import org.tensorflow.lite.support.image.TensorImage
|
||||||
import org.tensorflow.lite.support.image.ops.ResizeOp
|
|
||||||
import org.tensorflow.lite.support.image.ops.Rot90Op
|
|
||||||
import java.nio.ByteBuffer
|
import java.nio.ByteBuffer
|
||||||
import java.nio.FloatBuffer
|
import java.nio.FloatBuffer
|
||||||
|
|
||||||
@@ -44,7 +42,7 @@ class ImageSegmentationService(private val context: Context) {
|
|||||||
|
|
||||||
suspend fun initialize() {
|
suspend fun initialize() {
|
||||||
interpreter = try {
|
interpreter = try {
|
||||||
val litertBuffer = FileUtil.loadMappedFile(context, "deeplab_v3.tflite")
|
val litertBuffer = FileUtil.loadMappedFile(context, "mydeeplabv3.tflite")
|
||||||
Log.i(TAG, "Loaded LiteRT model")
|
Log.i(TAG, "Loaded LiteRT model")
|
||||||
val options = Interpreter.Options().apply {
|
val options = Interpreter.Options().apply {
|
||||||
numThreads = 2
|
numThreads = 2
|
||||||
@@ -63,21 +61,23 @@ class ImageSegmentationService(private val context: Context) {
|
|||||||
if (interpreter == null) return@withContext
|
if (interpreter == null) return@withContext
|
||||||
val startTime = SystemClock.uptimeMillis()
|
val startTime = SystemClock.uptimeMillis()
|
||||||
|
|
||||||
val rotation = -rotationDegrees / 90
|
val (_, _, h, w) = interpreter?.getInputTensor(0)?.shape() ?: return@withContext
|
||||||
val (_, h, w, _) = interpreter?.getOutputTensor(0)?.shape() ?: return@withContext
|
val dataType = interpreter?.getInputTensor(0)?.dataType()
|
||||||
val imageProcessor =
|
Log.i(TAG, "segment, input shape: ${interpreter!!.getInputTensor(0).shape().asList()} data type=${dataType}")
|
||||||
ImageProcessor
|
|
||||||
.Builder()
|
// Preprocess manually into CHW float buffer
|
||||||
.add(ResizeOp(h, w, ResizeOp.ResizeMethod.BILINEAR))
|
val inputBuffer = bitmapToCHWFloatBuffer(bitmap, width = w, height = h, rotationDegrees)
|
||||||
.add(Rot90Op(rotation))
|
|
||||||
.add(NormalizeOp(127.5f, 127.5f))
|
val (_, cOut, hOut, wOut) = interpreter!!.getOutputTensor(0).shape()
|
||||||
.build()
|
val outputBuffer = FloatBuffer.allocate(cOut * hOut * wOut)
|
||||||
|
|
||||||
|
// Run inference
|
||||||
|
outputBuffer.rewind()
|
||||||
|
interpreter?.run(inputBuffer, outputBuffer)
|
||||||
|
|
||||||
// Preprocess the image and convert it into a TensorImage for segmentation.
|
|
||||||
val tensorImage = imageProcessor.process(TensorImage.fromBitmap(bitmap))
|
|
||||||
val segmentResult = segment(tensorImage)
|
|
||||||
val inferenceTime = SystemClock.uptimeMillis() - startTime
|
val inferenceTime = SystemClock.uptimeMillis() - startTime
|
||||||
if (isActive) {
|
if (isActive) {
|
||||||
|
val segmentResult = processOutputBuffer(outputBuffer, wOut, hOut, cOut)
|
||||||
_segmentation.value = SegmentationResult(segmentResult, inferenceTime)
|
_segmentation.value = SegmentationResult(segmentResult, inferenceTime)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -87,14 +87,46 @@ class ImageSegmentationService(private val context: Context) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fun bitmapToCHWFloatBuffer(bitmap: Bitmap, width: Int, height: Int, rotationDegrees: Int): FloatBuffer {
|
||||||
|
val rotatedBitmap = if (rotationDegrees != 0) {
|
||||||
|
val matrix = Matrix().apply { postRotate(rotationDegrees.toFloat()) }
|
||||||
|
createBitmap(bitmap, 0, 0, bitmap.width, bitmap.height, matrix, true)
|
||||||
|
} else {
|
||||||
|
bitmap
|
||||||
|
}
|
||||||
|
|
||||||
private fun segment(tensorImage: TensorImage): Segmentation {
|
val resized = rotatedBitmap.scale(width, height)
|
||||||
val (_, h, w, c) = interpreter!!.getOutputTensor(0).shape()
|
val buffer = FloatBuffer.allocate(1 * 3 * height * width)
|
||||||
val outputBuffer = FloatBuffer.allocate(h * w * c)
|
buffer.rewind()
|
||||||
|
|
||||||
outputBuffer.rewind()
|
val mean = floatArrayOf(0.4611f, 0.4359f, 0.3905f)
|
||||||
interpreter?.run(tensorImage.tensorBuffer.buffer, outputBuffer)
|
val std = floatArrayOf(0.2193f, 0.2150f, 0.2109f)
|
||||||
|
|
||||||
|
val pixels = IntArray(width * height)
|
||||||
|
resized.getPixels(pixels, 0, width, 0, 0, width, height)
|
||||||
|
|
||||||
|
// Fill buffer in CHW order
|
||||||
|
for (c in 0..2) {
|
||||||
|
for (i in 0 until height) {
|
||||||
|
for (j in 0 until width) {
|
||||||
|
val pixel = pixels[i * width + j]
|
||||||
|
val value = when (c) {
|
||||||
|
0 -> (pixel shr 16 and 0xFF) // R
|
||||||
|
1 -> (pixel shr 8 and 0xFF) // G
|
||||||
|
2 -> (pixel and 0xFF) // B
|
||||||
|
else -> 0
|
||||||
|
}
|
||||||
|
val normalized = (value / 255f - mean[c]) / std[c]
|
||||||
|
buffer.put(normalized)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
buffer.rewind()
|
||||||
|
return buffer
|
||||||
|
}
|
||||||
|
|
||||||
|
private fun processOutputBuffer(outputBuffer: FloatBuffer, w: Int, h: Int, c: Int): Segmentation {
|
||||||
outputBuffer.rewind()
|
outputBuffer.rewind()
|
||||||
val inferenceData =
|
val inferenceData =
|
||||||
InferenceData(width = w, height = h, channels = c, buffer = outputBuffer)
|
InferenceData(width = w, height = h, channels = c, buffer = outputBuffer)
|
||||||
@@ -116,15 +148,15 @@ class ImageSegmentationService(private val context: Context) {
|
|||||||
val mask = ByteBuffer.allocateDirect(inferenceData.width * inferenceData.height)
|
val mask = ByteBuffer.allocateDirect(inferenceData.width * inferenceData.height)
|
||||||
for (i in 0 until inferenceData.height) {
|
for (i in 0 until inferenceData.height) {
|
||||||
for (j in 0 until inferenceData.width) {
|
for (j in 0 until inferenceData.width) {
|
||||||
val offset = inferenceData.channels * (i * inferenceData.width + j)
|
|
||||||
|
|
||||||
var maxIndex = 0
|
var maxIndex = 0
|
||||||
var maxValue = inferenceData.buffer.get(offset)
|
var maxValue = inferenceData.buffer.get(i * inferenceData.width + j)
|
||||||
|
|
||||||
for (index in 1 until inferenceData.channels) {
|
for (c in 1 until inferenceData.channels) {
|
||||||
if (inferenceData.buffer.get(offset + index) > maxValue) {
|
val value = inferenceData.buffer.get(
|
||||||
maxValue = inferenceData.buffer.get(offset + index)
|
c * inferenceData.height * inferenceData.width + i * inferenceData.width + j)
|
||||||
maxIndex = index
|
if (value > maxValue) {
|
||||||
|
maxValue = value
|
||||||
|
maxIndex = c
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
package org.mydomain.myscan
|
package org.mydomain.myscan
|
||||||
|
|
||||||
import android.content.Context
|
import android.content.Context
|
||||||
|
import android.graphics.Bitmap
|
||||||
import android.util.Log
|
import android.util.Log
|
||||||
import androidx.camera.core.ImageProxy
|
import androidx.camera.core.ImageProxy
|
||||||
import androidx.lifecycle.ViewModel
|
import androidx.lifecycle.ViewModel
|
||||||
@@ -34,7 +35,7 @@ class MainViewModel(private val imageSegmentationService: ImageSegmentationServi
|
|||||||
.filterNotNull()
|
.filterNotNull()
|
||||||
.map {
|
.map {
|
||||||
UiState(
|
UiState(
|
||||||
"Found ${numberOfObjectsDetected(it.segmentation)} objects!",
|
"Inference done",
|
||||||
it.inferenceTime,
|
it.inferenceTime,
|
||||||
it.segmentation.toBitmap())
|
it.segmentation.toBitmap())
|
||||||
}
|
}
|
||||||
@@ -45,18 +46,7 @@ class MainViewModel(private val imageSegmentationService: ImageSegmentationServi
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fun numberOfObjectsDetected(segmentation: ImageSegmentationService.Segmentation) : Int {
|
|
||||||
val tensor = segmentation.mask;
|
|
||||||
val buffer = tensor.buffer
|
|
||||||
val uniqueValues = HashSet<Int>()
|
|
||||||
for (i in 0..tensor.width * tensor.height - 1) {
|
|
||||||
uniqueValues.add(buffer[i].toInt())
|
|
||||||
}
|
|
||||||
return uniqueValues.size - 1;
|
|
||||||
}
|
|
||||||
|
|
||||||
fun segment(imageProxy: ImageProxy) {
|
fun segment(imageProxy: ImageProxy) {
|
||||||
Log.d("MyScan", "MainViewModel.Calling segment")
|
|
||||||
viewModelScope.launch {
|
viewModelScope.launch {
|
||||||
imageSegmentationService.runSegmentation(
|
imageSegmentationService.runSegmentation(
|
||||||
imageProxy.toBitmap(),
|
imageProxy.toBitmap(),
|
||||||
|
|||||||
Reference in New Issue
Block a user