diff --git a/app/src/main/java/org/fairscan/app/domain/ImageSegmentation.kt b/app/src/main/java/org/fairscan/app/domain/ImageSegmentation.kt index 149669f..97e922e 100644 --- a/app/src/main/java/org/fairscan/app/domain/ImageSegmentation.kt +++ b/app/src/main/java/org/fairscan/app/domain/ImageSegmentation.kt @@ -61,7 +61,7 @@ class ImageSegmentationService(private val context: Context, private val logger: } } - private fun runSegmentation(interpreter: Interpreter, bitmap: Bitmap, rotationDegrees: Int): SegmentationResult { + private fun runSegmentation(interpreter: Interpreter, bitmap: Bitmap): SegmentationResult { val startTime = SystemClock.uptimeMillis() val (_, h, w, _) = interpreter.getOutputTensor(0).shape() @@ -77,19 +77,15 @@ class ImageSegmentationService(private val context: Context, private val logger: val segmentResult = segment(interpreter, processedImage) val inferenceTime = SystemClock.uptimeMillis() - startTime - return SegmentationResult( - segmentResult, - ImageSize(bitmap.width, bitmap.height), - rotationDegrees, - inferenceTime) + return SegmentationResult(segmentResult, inferenceTime) } - suspend fun runSegmentationAndReturn(bitmap: Bitmap, rotationDegrees: Int): SegmentationResult? { + suspend fun runSegmentationAndReturn(bitmap: Bitmap): SegmentationResult? { if (interpreter == null) { return null } return inferenceLock.withLock { - runSegmentation(interpreter!!, bitmap, rotationDegrees) + runSegmentation(interpreter!!, bitmap) } } @@ -149,8 +145,6 @@ class ImageSegmentationService(private val context: Context, private val logger: data class SegmentationResult( val segmentation: Segmentation, - val originalSize: ImageSize, - val rotationDegrees: Int, val inferenceTime: Long ) } diff --git a/app/src/main/java/org/fairscan/app/ui/screens/camera/CameraUiState.kt b/app/src/main/java/org/fairscan/app/ui/screens/camera/CameraUiState.kt index 0dd9c25..79f05f0 100644 --- a/app/src/main/java/org/fairscan/app/ui/screens/camera/CameraUiState.kt +++ b/app/src/main/java/org/fairscan/app/ui/screens/camera/CameraUiState.kt @@ -24,7 +24,6 @@ data class LiveAnalysisState( val inferenceTime: Long = 0L, val maskSize: ImageSize? = null, val binaryMaskProvider: () -> Bitmap? = { -> null }, - val documentQuad: Quad? = null, val stableQuad: Quad? = null, ) diff --git a/app/src/main/java/org/fairscan/app/ui/screens/camera/CameraViewModel.kt b/app/src/main/java/org/fairscan/app/ui/screens/camera/CameraViewModel.kt index 71f1626..545e452 100644 --- a/app/src/main/java/org/fairscan/app/ui/screens/camera/CameraViewModel.kt +++ b/app/src/main/java/org/fairscan/app/ui/screens/camera/CameraViewModel.kt @@ -33,6 +33,7 @@ import org.fairscan.app.domain.CapturedPage import org.fairscan.app.domain.ExportQuality import org.fairscan.app.domain.PageMetadata import org.fairscan.app.domain.Rotation +import org.fairscan.imageprocessing.ImageSize import org.fairscan.imageprocessing.Mask import org.fairscan.imageprocessing.Quad import org.fairscan.imageprocessing.detectDocumentQuad @@ -100,30 +101,23 @@ class CameraViewModel(appContainer: AppContainer): ViewModel() { } viewModelScope.launch { + val rotationDegrees = imageProxy.imageInfo.rotationDegrees val result = withContext(Dispatchers.IO) { - imageSegmentationService.runSegmentationAndReturn( - imageProxy.toBitmap(), - imageProxy.imageInfo.rotationDegrees, - ) + imageSegmentationService.runSegmentationAndReturn(imageProxy.toBitmap()) } result?.let { + val segmentation = result.segmentation + val maskSize = segmentation.maskSize() + val originalSize = ImageSize(imageProxy.width, imageProxy.height) val rawQuad = withContext(Dispatchers.Default) { - detectDocumentQuad( - result.segmentation, - result.originalSize, - isLiveAnalysis = true - )?.rotate90( - result.rotationDegrees / 90, - result.segmentation.width, - result.segmentation.height - ) + detectDocumentQuad(segmentation, originalSize, isLiveAnalysis = true) + ?.rotate90(rotationDegrees / 90, maskSize) } val binaryMaskProvider = { -> - var binaryMask: Bitmap = result.segmentation.toBinaryMask() - if (result.rotationDegrees != 0) { - binaryMask = - rotateBitmap(binaryMask, result.rotationDegrees.toFloat()) + var binaryMask: Bitmap = segmentation.toBinaryMask() + if (rotationDegrees != 0) { + binaryMask = rotateBitmap(binaryMask, rotationDegrees.toFloat()) } binaryMask } @@ -131,8 +125,7 @@ class CameraViewModel(appContainer: AppContainer): ViewModel() { _liveAnalysisState.value = LiveAnalysisState( inferenceTime = result.inferenceTime, binaryMaskProvider = binaryMaskProvider, - maskSize = result.segmentation.maskSize(), - documentQuad = rawQuad, + maskSize = maskSize, stableQuad = stableQuad, ) } @@ -164,10 +157,11 @@ class CameraViewModel(appContainer: AppContainer): ViewModel() { rotationDegrees: Int, ): CapturedPage? = withContext(Dispatchers.IO) { var result: CapturedPage? = null - val segmentation = imageSegmentationService.runSegmentationAndReturn(source, 0) + val segmentation = imageSegmentationService.runSegmentationAndReturn(source) if (segmentation != null) { val mask = segmentation.segmentation - val quad = detectDocumentQuad(mask, segmentation.originalSize, isLiveAnalysis = false) + val originalSize = ImageSize(source.width, source.height) + val quad = detectDocumentQuad(mask, originalSize, isLiveAnalysis = false) if (quad != null) { val resizedQuad = quad.scaledTo(mask.width, mask.height, source.width, source.height) result = extractDocumentFromBitmap(source, resizedQuad, rotationDegrees, mask) diff --git a/imageprocessing/src/main/java/org/fairscan/imageprocessing/Geometry.kt b/imageprocessing/src/main/java/org/fairscan/imageprocessing/Geometry.kt index e4063d4..c4cea6a 100644 --- a/imageprocessing/src/main/java/org/fairscan/imageprocessing/Geometry.kt +++ b/imageprocessing/src/main/java/org/fairscan/imageprocessing/Geometry.kt @@ -47,16 +47,18 @@ data class Quad( Line(bottomLeft, topLeft)) } - fun rotate90(iterations: Int, imageWidth: Int, imageHeight: Int): Quad { + fun rotate90(iterations: Int, imageSize: ImageSize): Quad { val rotatedPoints = listOf( - rotate90(topLeft, imageWidth, imageHeight, iterations), - rotate90(topRight, imageWidth, imageHeight, iterations), - rotate90(bottomRight, imageWidth, imageHeight, iterations), - rotate90(bottomLeft, imageWidth, imageHeight, iterations) + rotate90(topLeft, imageSize, iterations), + rotate90(topRight, imageSize, iterations), + rotate90(bottomRight, imageSize, iterations), + rotate90(bottomLeft, imageSize, iterations) ) return createQuad(rotatedPoints) } - private fun rotate90(p: Point, width: Int, height: Int, iterations: Int): Point { + private fun rotate90(p: Point, imageSize: ImageSize, iterations: Int): Point { + val width = imageSize.width + val height = imageSize.height return when (iterations % 4) { 1 -> Point(height - p.y, p.x) // 90° 2 -> Point(width - p.x, height - p.y) // 180° diff --git a/imageprocessing/src/test/java/org/fairscan/imageprocessing/GeometryTest.kt b/imageprocessing/src/test/java/org/fairscan/imageprocessing/GeometryTest.kt index 4b68602..d7c7ff6 100644 --- a/imageprocessing/src/test/java/org/fairscan/imageprocessing/GeometryTest.kt +++ b/imageprocessing/src/test/java/org/fairscan/imageprocessing/GeometryTest.kt @@ -40,21 +40,21 @@ class GeometryTest { fun rotateQuad() { val quad = createQuad(listOf( Point(1,2), Point(10, 3), Point(11,12), Point(3, 9))) - assertThat(quad.rotate90(1, 100, 50)).isEqualTo( + assertThat(quad.rotate90(1, ImageSize(100, 50))).isEqualTo( createQuad(listOf( Point(48,1), Point(47, 10), Point(38,11), Point(41, 3) ))) - assertThat(quad.rotate90(2, 100, 50)).isEqualTo( + assertThat(quad.rotate90(2, ImageSize(100, 50))).isEqualTo( createQuad(listOf( Point(99,48), Point(90, 47), Point(89,38), Point(97, 41) ))) - assertThat(quad.rotate90(3, 100, 50)).isEqualTo( + assertThat(quad.rotate90(3, ImageSize(100, 50))).isEqualTo( createQuad(listOf( Point(2,99), Point(3, 90), Point(12,89), Point(9, 97) ))) - assertThat(quad.rotate90(4, 100, 50)).isEqualTo(quad) - assertThat(quad.rotate90(5, 100, 50)).isEqualTo( - quad.rotate90(1, 100, 50) + assertThat(quad.rotate90(4, ImageSize(100, 50))).isEqualTo(quad) + assertThat(quad.rotate90(5, ImageSize(100, 50))) + .isEqualTo(quad.rotate90(1, ImageSize(100, 50)) ) } }