Display a quadrilateral approximation of the detected document

This commit is contained in:
Pierre-Yves Nicolas
2025-05-30 14:26:28 +02:00
parent 0c3a666502
commit 02cc4a7627
9 changed files with 179 additions and 27 deletions

View File

@@ -58,6 +58,7 @@ dependencies {
implementation(libs.litert) implementation(libs.litert)
implementation(libs.litert.support) implementation(libs.litert.support)
implementation(libs.litert.metadata) implementation(libs.litert.metadata)
implementation(libs.opencv)
testImplementation(libs.junit) testImplementation(libs.junit)
androidTestImplementation(libs.androidx.junit) androidTestImplementation(libs.androidx.junit)

View File

@@ -0,0 +1,49 @@
package org.mydomain.myscan
import android.graphics.Bitmap
import org.opencv.android.Utils
import org.opencv.core.Mat
import org.opencv.core.MatOfPoint
import org.opencv.core.MatOfPoint2f
import org.opencv.core.Size
import org.opencv.imgproc.Imgproc
import kotlin.math.abs
/**
 * Looks for the largest four-sided contour in [mask] and returns it as a [Quad].
 *
 * Pipeline: bitmap -> grayscale -> 5x5 Gaussian blur -> Canny edges -> contours ->
 * polygon approximation (tolerance of 2% of each contour's perimeter). Among the
 * approximations that have exactly four vertices, the one with the largest area wins.
 * Vertex coordinates are truncated to integers before being handed to [createQuad].
 *
 * Every OpenCV `Mat` allocated here is backed by native memory, so all of them are
 * released explicitly before returning instead of waiting for finalizers.
 *
 * @param mask bitmap to analyse.
 * @return the detected quadrilateral, or `null` when no four-vertex contour was found.
 */
fun detectDocumentQuad(mask: Bitmap): Quad? {
    val mat = Mat()
    val gray = Mat()
    val blurred = Mat()
    val edges = Mat()
    val hierarchy = Mat()
    val contours = mutableListOf<MatOfPoint>()
    try {
        Utils.bitmapToMat(mask, mat)
        // bitmapToMat yields an RGBA Mat, so use the matching conversion code.
        Imgproc.cvtColor(mat, gray, Imgproc.COLOR_RGBA2GRAY)
        Imgproc.GaussianBlur(gray, blurred, Size(5.0, 5.0), 0.0)
        Imgproc.Canny(blurred, edges, 75.0, 200.0)
        Imgproc.findContours(edges, contours, hierarchy, Imgproc.RETR_LIST, Imgproc.CHAIN_APPROX_SIMPLE)

        var bestVertices: List<Point>? = null
        var maxArea = 0.0
        for (contour in contours) {
            val contour2f = MatOfPoint2f(*contour.toArray())
            val approx = MatOfPoint2f()
            try {
                val peri = Imgproc.arcLength(contour2f, true)
                Imgproc.approxPolyDP(contour2f, approx, 0.02 * peri, true)
                if (approx.total() == 4L) {
                    val area = abs(Imgproc.contourArea(approx))
                    if (area > maxArea) {
                        maxArea = area
                        // Copy the vertices out now: `approx` is released right below.
                        bestVertices = approx.toList().map { Point(it.x.toInt(), it.y.toInt()) }
                    }
                }
            } finally {
                contour2f.release()
                approx.release()
            }
        }
        return createQuad(bestVertices)
    } finally {
        contours.forEach { it.release() }
        mat.release()
        gray.release()
        blurred.release()
        edges.release()
        hierarchy.release()
    }
}

View File

@@ -0,0 +1,52 @@
package org.mydomain.myscan
import kotlin.math.atan2
/** A position expressed in integer pixel coordinates. */
data class Point(
    val x: Int,
    val y: Int,
)
/** A straight segment going from [from] to [to]. */
data class Line(
    val from: Point,
    val to: Point,
)
/** A four-sided polygon described by its four named corners. */
data class Quad(
    val topLeft: Point,
    val topRight: Point,
    val bottomRight: Point,
    val bottomLeft: Point
) {
    /**
     * Returns the four sides in the order top-left -> top-right -> bottom-right ->
     * bottom-left -> top-left, each side running from one corner to the next.
     */
    fun edges(): List<Line> {
        val corners = listOf(topLeft, topRight, bottomRight, bottomLeft)
        return corners.indices.map { index ->
            // The modulo wraps the last corner back to the first one.
            Line(corners[index], corners[(index + 1) % corners.size])
        }
    }
}
/**
 * Builds a [Quad] from exactly four vertices given in any order.
 *
 * The vertices are ordered by their `atan2` angle around the centroid (ascending),
 * then assigned to top-left, top-right, bottom-right and bottom-left in that order.
 *
 * @param vertices candidate corners; may be `null`.
 * @return the quad, or `null` when [vertices] is `null` or does not hold four points.
 */
fun createQuad(vertices: List<Point>?): Quad? {
    val corners = vertices?.takeIf { it.size == 4 } ?: return null

    // Centroid of the four corners.
    val centroidX = corners.map(Point::x).average()
    val centroidY = corners.map(Point::y).average()

    // Ascending angle around the centroid.
    val ordered = corners.sortedBy { corner ->
        atan2(corner.y - centroidY, corner.x - centroidX)
    }

    val (topLeft, topRight, bottomRight, bottomLeft) = ordered
    return Quad(topLeft, topRight, bottomRight, bottomLeft)
}
/**
 * Maps this quad from a [fromWidth] x [fromHeight] coordinate space into a
 * [toWidth] x [toHeight] one by scaling every corner independently on each axis.
 */
fun Quad.scaledTo(fromWidth: Int, fromHeight: Int, toWidth: Int, toHeight: Int): Quad {
    val horizontalFactor = toWidth.toFloat() / fromWidth
    val verticalFactor = toHeight.toFloat() / fromHeight

    fun rescale(corner: Point): Point = corner.scaled(horizontalFactor, verticalFactor)

    return Quad(
        rescale(topLeft),
        rescale(topRight),
        rescale(bottomRight),
        rescale(bottomLeft),
    )
}
/**
 * Returns a copy of this point with `x` multiplied by [scaleX] and `y` by [scaleY];
 * each product is converted back to `Int` with `toInt()` (fraction dropped).
 */
fun Point.scaled(scaleX: Float, scaleY: Float): Point =
    Point(
        x = (x * scaleX).toInt(),
        y = (y * scaleY).toInt(),
    )

View File

@@ -3,7 +3,7 @@ package org.mydomain.myscan
import android.content.Context import android.content.Context
import android.graphics.Bitmap import android.graphics.Bitmap
import android.graphics.Bitmap.createBitmap import android.graphics.Bitmap.createBitmap
import android.graphics.Color.argb import android.graphics.Color
import android.graphics.Matrix import android.graphics.Matrix
import android.os.SystemClock import android.os.SystemClock
import android.util.Log import android.util.Log
@@ -168,23 +168,21 @@ class ImageSegmentationService(private val context: Context) {
} }
data class Segmentation(val mask: TensorImage) { data class Segmentation(val mask: TensorImage) {
fun toBitmap(): Bitmap { fun toBinaryMask(): Bitmap {
val width = mask.width val width = mask.width
val height = mask.height val height = mask.height
val pixels = IntArray(width * height) val pixels = IntArray(width * height)
val green = argb(128, 0, 255, 0)
for (i in 0 until height) { for (i in 0 until height) {
for (j in 0 until width) { for (j in 0 until width) {
val index = i * width + j val index = i * width + j
val classId = mask.buffer[index].toInt() and 0xFF // Unsigned byte val classId = mask.buffer[index].toInt() and 0xFF // Unsigned byte
pixels[index] = if (classId == 0) 0 else green pixels[index] = if (classId == 0) Color.BLACK else Color.WHITE
} }
} }
return createBitmap(pixels, width, height, Bitmap.Config.ARGB_8888) return createBitmap(pixels, width, height, Bitmap.Config.ARGB_8888)
} }
} }
data class SegmentationResult( data class SegmentationResult(
val segmentation: Segmentation, val segmentation: Segmentation,
val inferenceTime: Long val inferenceTime: Long

View File

@@ -1,6 +1,7 @@
package org.mydomain.myscan package org.mydomain.myscan
import android.os.Bundle import android.os.Bundle
import android.util.Log
import androidx.activity.ComponentActivity import androidx.activity.ComponentActivity
import androidx.activity.compose.setContent import androidx.activity.compose.setContent
import androidx.activity.enableEdgeToEdge import androidx.activity.enableEdgeToEdge
@@ -21,11 +22,13 @@ import androidx.compose.ui.unit.dp
import androidx.lifecycle.compose.collectAsStateWithLifecycle import androidx.lifecycle.compose.collectAsStateWithLifecycle
import org.mydomain.myscan.ui.theme.MyScanTheme import org.mydomain.myscan.ui.theme.MyScanTheme
import org.mydomain.myscan.view.CameraScreen import org.mydomain.myscan.view.CameraScreen
import org.opencv.android.OpenCVLoader
class MainActivity : ComponentActivity() { class MainActivity : ComponentActivity() {
override fun onCreate(savedInstanceState: Bundle?) { override fun onCreate(savedInstanceState: Bundle?) {
super.onCreate(savedInstanceState) super.onCreate(savedInstanceState)
initOpenCV()
val viewModel: MainViewModel by viewModels { MainViewModel.getFactory(this) } val viewModel: MainViewModel by viewModels { MainViewModel.getFactory(this) }
enableEdgeToEdge() enableEdgeToEdge()
setContent { setContent {
@@ -43,6 +46,14 @@ class MainActivity : ComponentActivity() {
} }
} }
} }
/** Calls `OpenCVLoader.initLocal()` and logs whether it reported success. */
private fun initOpenCV() {
    val loaded = OpenCVLoader.initLocal()
    if (loaded) {
        Log.d("OpenCV", "Initialization successful")
    } else {
        Log.e("OpenCV", "Initialization failed")
    }
}
} }
@Composable @Composable

View File

@@ -1,8 +1,6 @@
package org.mydomain.myscan package org.mydomain.myscan
import android.content.Context import android.content.Context
import android.graphics.Bitmap
import android.util.Log
import androidx.camera.core.ImageProxy import androidx.camera.core.ImageProxy
import androidx.lifecycle.ViewModel import androidx.lifecycle.ViewModel
import androidx.lifecycle.ViewModelProvider import androidx.lifecycle.ViewModelProvider
@@ -34,13 +32,15 @@ class MainViewModel(private val imageSegmentationService: ImageSegmentationServi
imageSegmentationService.segmentation imageSegmentationService.segmentation
.filterNotNull() .filterNotNull()
.map { .map {
val binaryMask = it.segmentation.toBinaryMask()
UiState( UiState(
"Inference done", detectionMessage = "Inference done",
it.inferenceTime, inferenceTime = it.inferenceTime,
it.segmentation.toBitmap()) binaryMask = binaryMask,
documentQuad = detectDocumentQuad(binaryMask)
)
} }
.collect { .collect {
Log.d("MyScan", "New UIstate ${it}")
_uiState.value = it _uiState.value = it
} }
} }

View File

@@ -7,6 +7,7 @@ import androidx.compose.runtime.Immutable
data class UiState( data class UiState(
val detectionMessage: String? = null, val detectionMessage: String? = null,
val inferenceTime: Long = 0L, val inferenceTime: Long = 0L,
val overlayBitmap: Bitmap? = null, val binaryMask: Bitmap? = null,
val errorMessage: String? = null, val errorMessage: String? = null,
val documentQuad: Quad? = null,
) )

View File

@@ -26,7 +26,12 @@ import androidx.compose.runtime.getValue
import androidx.compose.runtime.mutableStateOf import androidx.compose.runtime.mutableStateOf
import androidx.compose.runtime.remember import androidx.compose.runtime.remember
import androidx.compose.ui.Modifier import androidx.compose.ui.Modifier
import androidx.compose.ui.geometry.Offset
import androidx.compose.ui.graphics.BlendMode
import androidx.compose.ui.graphics.Color
import androidx.compose.ui.graphics.ColorFilter
import androidx.compose.ui.graphics.asImageBitmap import androidx.compose.ui.graphics.asImageBitmap
import androidx.compose.ui.graphics.toArgb
import androidx.compose.ui.platform.LocalConfiguration import androidx.compose.ui.platform.LocalConfiguration
import androidx.compose.ui.platform.LocalContext import androidx.compose.ui.platform.LocalContext
import androidx.compose.ui.unit.dp import androidx.compose.ui.unit.dp
@@ -38,6 +43,9 @@ import com.google.common.util.concurrent.ListenableFuture
import org.mydomain.myscan.UiState import org.mydomain.myscan.UiState
import java.util.concurrent.ExecutorService import java.util.concurrent.ExecutorService
import java.util.concurrent.Executors import java.util.concurrent.Executors
import androidx.core.graphics.scale
import org.mydomain.myscan.Point
import org.mydomain.myscan.scaledTo
@Composable @Composable
fun CameraScreen( fun CameraScreen(
@@ -69,12 +77,7 @@ fun CameraScreen(
.height(height.dp) .height(height.dp)
) { ) {
CameraPreview(onImageAnalyzed = onImageAnalyzed) CameraPreview(onImageAnalyzed = onImageAnalyzed)
if (uiState.overlayBitmap != null) { AnalysisOverlay(uiState)
SegmentationOverlay(
modifier = Modifier.fillMaxSize(),
overlay = uiState.overlayBitmap
)
}
} }
} }
@@ -138,14 +141,49 @@ fun bindCameraUseCases(
} }
@Composable @Composable
fun SegmentationOverlay(modifier: Modifier = Modifier, overlay: Bitmap) { private fun AnalysisOverlay(uiState: UiState) {
Canvas( if (uiState.binaryMask == null) {
modifier = modifier return
) { }
val imageWidth: Float = size.width val maskOverlay = replaceColor(uiState.binaryMask, Color.Black, Color.Transparent)
val imageHeight: Float = size.height Canvas(modifier = Modifier.fillMaxSize()) {
val scaleBitmap = drawImage(
Bitmap.createScaledBitmap(overlay, imageWidth.toInt(), imageHeight.toInt(), true) maskOverlay.scale(size.width.toInt(), size.height.toInt()).asImageBitmap(),
drawImage(scaleBitmap.asImageBitmap()) colorFilter = ColorFilter.tint(Color(0x8000FF00), BlendMode.SrcIn)
)
if (uiState.documentQuad != null) {
val scaledQuad = uiState.documentQuad.scaledTo(
fromWidth = uiState.binaryMask.width,
fromHeight = uiState.binaryMask.height,
toWidth = size.width.toInt(),
toHeight = size.height.toInt()
)
scaledQuad.edges().forEach {
drawLine(Color.Green, it.from.toOffset(), it.to.toOffset(), 5.0f)
} }
} }
}
}
/**
 * Returns a mutable ARGB_8888 copy of [bitmap] in which every pixel whose ARGB value
 * equals [toReplace] exactly has been set to [replacement]. The input is not modified.
 */
fun replaceColor(bitmap: Bitmap, toReplace: Color, replacement: Color): Bitmap {
    val sourceArgb = toReplace.toArgb()
    val replacementArgb = replacement.toArgb()
    val w = bitmap.width
    val h = bitmap.height

    val output = bitmap.copy(Bitmap.Config.ARGB_8888, true)
    val buffer = IntArray(w * h)
    output.getPixels(buffer, 0, w, 0, 0, w, h)

    buffer.forEachIndexed { index, argb ->
        if (argb == sourceArgb) {
            buffer[index] = replacementArgb
        }
    }

    output.setPixels(buffer, 0, w, 0, 0, w, h)
    return output
}
/** Converts this integer point to a Compose [Offset] with the same coordinates. */
fun Point.toOffset(): Offset {
    val offsetX = x.toFloat()
    val offsetY = y.toFloat()
    return Offset(offsetX, offsetY)
}

View File

@@ -10,6 +10,7 @@ activityCompose = "1.10.1"
composeBom = "2025.05.00" composeBom = "2025.05.00"
camerax = "1.4.2" camerax = "1.4.2"
litert = "1.2.0" litert = "1.2.0"
opencv = "4.11.0"
[libraries] [libraries]
androidx-core-ktx = { group = "androidx.core", name = "core-ktx", version.ref = "coreKtx" } androidx-core-ktx = { group = "androidx.core", name = "core-ktx", version.ref = "coreKtx" }
@@ -36,6 +37,7 @@ androidx-camera-view = { group = "androidx.camera", name = "camera-view", versio
litert = { group = "com.google.ai.edge.litert", name = "litert", version.ref = "litert" } litert = { group = "com.google.ai.edge.litert", name = "litert", version.ref = "litert" }
litert-support = { group = "com.google.ai.edge.litert", name = "litert-support", version.ref = "litert" } litert-support = { group = "com.google.ai.edge.litert", name = "litert-support", version.ref = "litert" }
litert-metadata = { group = "com.google.ai.edge.litert", name = "litert-metadata", version.ref = "litert" } litert-metadata = { group = "com.google.ai.edge.litert", name = "litert-metadata", version.ref = "litert" }
opencv = { group="org.opencv", name="opencv", version.ref = "opencv" }
[plugins] [plugins]
android-application = { id = "com.android.application", version.ref = "agp" } android-application = { id = "com.android.application", version.ref = "agp" }