Display a quadrilateral approximation of the detected document
This commit is contained in:
@@ -58,6 +58,7 @@ dependencies {
|
||||
implementation(libs.litert)
|
||||
implementation(libs.litert.support)
|
||||
implementation(libs.litert.metadata)
|
||||
implementation(libs.opencv)
|
||||
|
||||
testImplementation(libs.junit)
|
||||
androidTestImplementation(libs.androidx.junit)
|
||||
|
||||
49
app/src/main/java/org/mydomain/myscan/DocumentDetection.kt
Normal file
49
app/src/main/java/org/mydomain/myscan/DocumentDetection.kt
Normal file
@@ -0,0 +1,49 @@
|
||||
package org.mydomain.myscan
|
||||
|
||||
import android.graphics.Bitmap
|
||||
import org.opencv.android.Utils
|
||||
import org.opencv.core.Mat
|
||||
import org.opencv.core.MatOfPoint
|
||||
import org.opencv.core.MatOfPoint2f
|
||||
import org.opencv.core.Size
|
||||
import org.opencv.imgproc.Imgproc
|
||||
import kotlin.math.abs
|
||||
|
||||
/**
 * Finds the largest four-sided contour in [mask] and returns it as a [Quad].
 *
 * Pipeline: bitmap -> grayscale -> 5x5 Gaussian blur -> Canny edges -> contours ->
 * polygon approximation (tolerance = 2% of the contour perimeter). Among the
 * approximations that have exactly four vertices, the one with the largest area wins.
 *
 * All OpenCV matrices allocated here wrap native memory; they are released explicitly
 * before returning instead of waiting for the garbage collector.
 *
 * @param mask bitmap to analyse; the caller in this commit passes a black/white segmentation mask.
 * @return the detected quadrilateral in [mask] pixel coordinates, or `null` when no
 *   four-vertex contour with a positive area was found.
 */
fun detectDocumentQuad(mask: Bitmap): Quad? {
    // Long-lived native matrices, released together in the outer `finally`.
    val scratch = mutableListOf<Mat>()
    fun <T : Mat> track(mat: T): T = mat.also { scratch += it }
    val contours = mutableListOf<MatOfPoint>()

    try {
        val rgba = track(Mat())
        Utils.bitmapToMat(mask, rgba)

        // bitmapToMat yields an RGBA matrix, so convert from RGBA rather than BGR.
        val gray = track(Mat())
        Imgproc.cvtColor(rgba, gray, Imgproc.COLOR_RGBA2GRAY)

        val blurred = track(Mat())
        Imgproc.GaussianBlur(gray, blurred, Size(5.0, 5.0), 0.0)

        val edges = track(Mat())
        Imgproc.Canny(blurred, edges, 75.0, 200.0)

        val hierarchy = track(Mat())
        Imgproc.findContours(edges, contours, hierarchy, Imgproc.RETR_LIST, Imgproc.CHAIN_APPROX_SIMPLE)

        var bestVertices: List<Point>? = null
        var maxArea = 0.0

        for (contour in contours) {
            val contour2f = MatOfPoint2f(*contour.toArray())
            val approx = MatOfPoint2f()
            try {
                val peri = Imgproc.arcLength(contour2f, true)
                Imgproc.approxPolyDP(contour2f, approx, 0.02 * peri, true)

                if (approx.total() == 4L) {
                    val area = abs(Imgproc.contourArea(approx))
                    if (area > maxArea) {
                        maxArea = area
                        // Copy the vertices out now: `approx` is released below.
                        bestVertices = approx.toList().map { Point(it.x.toInt(), it.y.toInt()) }
                    }
                }
            } finally {
                // Per-contour temporaries are freed as soon as the iteration ends.
                contour2f.release()
                approx.release()
            }
        }

        return createQuad(bestVertices)
    } finally {
        (scratch + contours).forEach { it.release() }
    }
}
|
||||
52
app/src/main/java/org/mydomain/myscan/Geometry.kt
Normal file
52
app/src/main/java/org/mydomain/myscan/Geometry.kt
Normal file
@@ -0,0 +1,52 @@
|
||||
package org.mydomain.myscan
|
||||
|
||||
import kotlin.math.atan2
|
||||
|
||||
/** An integer 2D coordinate. */
data class Point(
    val x: Int,
    val y: Int,
)
|
||||
|
||||
/** A segment going from one [Point] to another. */
data class Line(
    val from: Point,
    val to: Point,
)
|
||||
|
||||
/** A quadrilateral described by its four corners. */
data class Quad(
    val topLeft: Point,
    val topRight: Point,
    val bottomRight: Point,
    val bottomLeft: Point
) {
    /**
     * Returns the four sides in the order top, right, bottom, left; each [Line] starts at
     * one corner and ends at the next one, the last side closing back onto [topLeft].
     */
    fun edges(): List<Line> {
        val corners = listOf(topLeft, topRight, bottomRight, bottomLeft)
        return corners.indices.map { i ->
            // Wrap around so the final edge joins the last corner to the first.
            Line(corners[i], corners[(i + 1) % corners.size])
        }
    }
}
|
||||
|
||||
/**
 * Builds a [Quad] from four unordered vertices.
 *
 * The vertices are ordered by their angle around the centroid (ascending `atan2`,
 * described as clockwise by the original author, which holds when y grows downward)
 * and assigned, in that order, to top-left, top-right, bottom-right and bottom-left.
 *
 * @param vertices candidate corners; may be `null`.
 * @return the quad, or `null` when [vertices] is `null` or does not hold exactly four points.
 */
fun createQuad(vertices: List<Point>?): Quad? {
    val corners = vertices?.takeIf { it.size == 4 } ?: return null

    // Centroid of the four corners.
    val centerX = corners.map { it.x }.average()
    val centerY = corners.map { it.y }.average()

    val (first, second, third, fourth) = corners.sortedBy { corner ->
        atan2(corner.y - centerY, corner.x - centerX)
    }
    return Quad(first, second, third, fourth)
}
|
||||
|
||||
/**
 * Maps this quad from a `fromWidth` x `fromHeight` coordinate space into a
 * `toWidth` x `toHeight` one by scaling each corner independently on both axes.
 *
 * @return a new [Quad]; the receiver is left untouched.
 */
fun Quad.scaledTo(fromWidth: Int, fromHeight: Int, toWidth: Int, toHeight: Int): Quad {
    val horizontalFactor = toWidth.toFloat() / fromWidth
    val verticalFactor = toHeight.toFloat() / fromHeight

    fun rescale(corner: Point): Point = corner.scaled(horizontalFactor, verticalFactor)

    return Quad(
        rescale(topLeft),
        rescale(topRight),
        rescale(bottomRight),
        rescale(bottomLeft)
    )
}
|
||||
|
||||
/**
 * Returns a copy of this point with `x` multiplied by [scaleX] and `y` by [scaleY];
 * each product is converted back to `Int` with `toInt()` (fractional part dropped).
 */
fun Point.scaled(scaleX: Float, scaleY: Float): Point =
    Point(
        (x * scaleX).toInt(),
        (y * scaleY).toInt()
    )
|
||||
@@ -3,7 +3,7 @@ package org.mydomain.myscan
|
||||
import android.content.Context
|
||||
import android.graphics.Bitmap
|
||||
import android.graphics.Bitmap.createBitmap
|
||||
import android.graphics.Color.argb
|
||||
import android.graphics.Color
|
||||
import android.graphics.Matrix
|
||||
import android.os.SystemClock
|
||||
import android.util.Log
|
||||
@@ -168,23 +168,21 @@ class ImageSegmentationService(private val context: Context) {
|
||||
}
|
||||
|
||||
data class Segmentation(val mask: TensorImage) {
|
||||
fun toBitmap(): Bitmap {
|
||||
fun toBinaryMask(): Bitmap {
|
||||
val width = mask.width
|
||||
val height = mask.height
|
||||
val pixels = IntArray(width * height)
|
||||
val green = argb(128, 0, 255, 0)
|
||||
for (i in 0 until height) {
|
||||
for (j in 0 until width) {
|
||||
val index = i * width + j
|
||||
val classId = mask.buffer[index].toInt() and 0xFF // Unsigned byte
|
||||
pixels[index] = if (classId == 0) 0 else green
|
||||
pixels[index] = if (classId == 0) Color.BLACK else Color.WHITE
|
||||
}
|
||||
}
|
||||
return createBitmap(pixels, width, height, Bitmap.Config.ARGB_8888)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
data class SegmentationResult(
|
||||
val segmentation: Segmentation,
|
||||
val inferenceTime: Long
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package org.mydomain.myscan
|
||||
|
||||
import android.os.Bundle
|
||||
import android.util.Log
|
||||
import androidx.activity.ComponentActivity
|
||||
import androidx.activity.compose.setContent
|
||||
import androidx.activity.enableEdgeToEdge
|
||||
@@ -21,11 +22,13 @@ import androidx.compose.ui.unit.dp
|
||||
import androidx.lifecycle.compose.collectAsStateWithLifecycle
|
||||
import org.mydomain.myscan.ui.theme.MyScanTheme
|
||||
import org.mydomain.myscan.view.CameraScreen
|
||||
import org.opencv.android.OpenCVLoader
|
||||
|
||||
class MainActivity : ComponentActivity() {
|
||||
|
||||
override fun onCreate(savedInstanceState: Bundle?) {
|
||||
super.onCreate(savedInstanceState)
|
||||
initOpenCV()
|
||||
val viewModel: MainViewModel by viewModels { MainViewModel.getFactory(this) }
|
||||
enableEdgeToEdge()
|
||||
setContent {
|
||||
@@ -43,6 +46,14 @@ class MainActivity : ComponentActivity() {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
 * Initialises OpenCV through [OpenCVLoader.initLocal] and logs the outcome under the
 * "OpenCV" tag. A failure is only logged; nothing is thrown from here.
 */
private fun initOpenCV() {
    val loaded = OpenCVLoader.initLocal()
    if (loaded) {
        Log.d("OpenCV", "Initialization successful")
    } else {
        Log.e("OpenCV", "Initialization failed")
    }
}
|
||||
}
|
||||
|
||||
@Composable
|
||||
|
||||
@@ -1,8 +1,6 @@
|
||||
package org.mydomain.myscan
|
||||
|
||||
import android.content.Context
|
||||
import android.graphics.Bitmap
|
||||
import android.util.Log
|
||||
import androidx.camera.core.ImageProxy
|
||||
import androidx.lifecycle.ViewModel
|
||||
import androidx.lifecycle.ViewModelProvider
|
||||
@@ -34,13 +32,15 @@ class MainViewModel(private val imageSegmentationService: ImageSegmentationServi
|
||||
imageSegmentationService.segmentation
|
||||
.filterNotNull()
|
||||
.map {
|
||||
val binaryMask = it.segmentation.toBinaryMask()
|
||||
UiState(
|
||||
"Inference done",
|
||||
it.inferenceTime,
|
||||
it.segmentation.toBitmap())
|
||||
detectionMessage = "Inference done",
|
||||
inferenceTime = it.inferenceTime,
|
||||
binaryMask = binaryMask,
|
||||
documentQuad = detectDocumentQuad(binaryMask)
|
||||
)
|
||||
}
|
||||
.collect {
|
||||
Log.d("MyScan", "New UIstate ${it}")
|
||||
_uiState.value = it
|
||||
}
|
||||
}
|
||||
|
||||
@@ -7,6 +7,7 @@ import androidx.compose.runtime.Immutable
|
||||
data class UiState(
|
||||
val detectionMessage: String? = null,
|
||||
val inferenceTime: Long = 0L,
|
||||
val overlayBitmap: Bitmap? = null,
|
||||
val binaryMask: Bitmap? = null,
|
||||
val errorMessage: String? = null,
|
||||
val documentQuad: Quad? = null,
|
||||
)
|
||||
|
||||
@@ -26,7 +26,12 @@ import androidx.compose.runtime.getValue
|
||||
import androidx.compose.runtime.mutableStateOf
|
||||
import androidx.compose.runtime.remember
|
||||
import androidx.compose.ui.Modifier
|
||||
import androidx.compose.ui.geometry.Offset
|
||||
import androidx.compose.ui.graphics.BlendMode
|
||||
import androidx.compose.ui.graphics.Color
|
||||
import androidx.compose.ui.graphics.ColorFilter
|
||||
import androidx.compose.ui.graphics.asImageBitmap
|
||||
import androidx.compose.ui.graphics.toArgb
|
||||
import androidx.compose.ui.platform.LocalConfiguration
|
||||
import androidx.compose.ui.platform.LocalContext
|
||||
import androidx.compose.ui.unit.dp
|
||||
@@ -38,6 +43,9 @@ import com.google.common.util.concurrent.ListenableFuture
|
||||
import org.mydomain.myscan.UiState
|
||||
import java.util.concurrent.ExecutorService
|
||||
import java.util.concurrent.Executors
|
||||
import androidx.core.graphics.scale
|
||||
import org.mydomain.myscan.Point
|
||||
import org.mydomain.myscan.scaledTo
|
||||
|
||||
@Composable
|
||||
fun CameraScreen(
|
||||
@@ -69,12 +77,7 @@ fun CameraScreen(
|
||||
.height(height.dp)
|
||||
) {
|
||||
CameraPreview(onImageAnalyzed = onImageAnalyzed)
|
||||
if (uiState.overlayBitmap != null) {
|
||||
SegmentationOverlay(
|
||||
modifier = Modifier.fillMaxSize(),
|
||||
overlay = uiState.overlayBitmap
|
||||
)
|
||||
}
|
||||
AnalysisOverlay(uiState)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -138,14 +141,49 @@ fun bindCameraUseCases(
|
||||
}
|
||||
|
||||
@Composable
|
||||
fun SegmentationOverlay(modifier: Modifier = Modifier, overlay: Bitmap) {
|
||||
Canvas(
|
||||
modifier = modifier
|
||||
) {
|
||||
val imageWidth: Float = size.width
|
||||
val imageHeight: Float = size.height
|
||||
val scaleBitmap =
|
||||
Bitmap.createScaledBitmap(overlay, imageWidth.toInt(), imageHeight.toInt(), true)
|
||||
drawImage(scaleBitmap.asImageBitmap())
|
||||
/**
 * Draws the analysis result over the whole available area: the binary mask as a
 * half-transparent green tint and, when one was detected, the outline of the document quad.
 * Renders nothing while [UiState.binaryMask] is `null`.
 */
private fun AnalysisOverlay(uiState: UiState) {
    val mask = uiState.binaryMask ?: return

    // Black pixels are made transparent so that only the remaining pixels get tinted.
    val maskOverlay = replaceColor(mask, Color.Black, Color.Transparent)

    Canvas(modifier = Modifier.fillMaxSize()) {
        val canvasWidth = size.width.toInt()
        val canvasHeight = size.height.toInt()

        drawImage(
            maskOverlay.scale(canvasWidth, canvasHeight).asImageBitmap(),
            colorFilter = ColorFilter.tint(Color(0x8000FF00), BlendMode.SrcIn)
        )

        uiState.documentQuad?.let { quad ->
            // The quad is expressed in mask pixels; bring it into canvas pixels first.
            val onCanvas = quad.scaledTo(
                fromWidth = mask.width,
                fromHeight = mask.height,
                toWidth = canvasWidth,
                toHeight = canvasHeight
            )
            for (edge in onCanvas.edges()) {
                drawLine(Color.Green, edge.from.toOffset(), edge.to.toOffset(), 5.0f)
            }
        }
    }
}
|
||||
|
||||
/**
 * Returns a mutable ARGB_8888 copy of [bitmap] in which every pixel whose ARGB value
 * equals [toReplace] exactly has been set to [replacement]. The input bitmap is not modified.
 */
fun replaceColor(bitmap: Bitmap, toReplace: Color, replacement: Color): Bitmap {
    val w = bitmap.width
    val h = bitmap.height
    val copy = bitmap.copy(Bitmap.Config.ARGB_8888, true)

    val oldArgb = toReplace.toArgb()
    val newArgb = replacement.toArgb()

    // Work on a flat pixel buffer: one bulk read, in-place substitution, one bulk write.
    val buffer = IntArray(w * h).also { copy.getPixels(it, 0, w, 0, 0, w, h) }
    buffer.forEachIndexed { idx, argb ->
        if (argb == oldArgb) buffer[idx] = newArgb
    }
    copy.setPixels(buffer, 0, w, 0, 0, w, h)

    return copy
}
|
||||
|
||||
/** Converts this integer [Point] to a Compose [Offset] with the same coordinates. */
fun Point.toOffset(): Offset = Offset(x = x.toFloat(), y = y.toFloat())
|
||||
|
||||
Reference in New Issue
Block a user