Run segmentation and quad detection in sequence to avoid CPU contention

This commit is contained in:
Pierre-Yves Nicolas
2026-03-08 14:22:59 +01:00
parent 343495dafe
commit 4e3cc95979
2 changed files with 37 additions and 57 deletions

View File

@@ -20,14 +20,8 @@ import android.graphics.Bitmap.createBitmap
import android.graphics.Color
import android.os.SystemClock
import android.util.Log
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.flow.MutableStateFlow
import kotlinx.coroutines.flow.StateFlow
import kotlinx.coroutines.flow.asStateFlow
import kotlinx.coroutines.isActive
import kotlinx.coroutines.sync.Mutex
import kotlinx.coroutines.sync.withLock
import kotlinx.coroutines.withContext
import org.fairscan.app.data.Logger
import org.fairscan.imageprocessing.ImageSize
import org.fairscan.imageprocessing.Mask
@@ -49,9 +43,6 @@ class ImageSegmentationService(private val context: Context, private val logger:
private const val TAG = "ImageSegmentation"
}
private val _segmentation = MutableStateFlow<SegmentationResult?>(null)
val segmentation: StateFlow<SegmentationResult?> = _segmentation.asStateFlow()
private var interpreter: Interpreter? = null
private val inferenceLock = Mutex()
@@ -102,19 +93,6 @@ class ImageSegmentationService(private val context: Context, private val logger:
}
}
/**
 * Runs segmentation for [bitmap] on [Dispatchers.IO] and publishes the outcome to the
 * `_segmentation` state flow (exposed as [segmentation]).
 *
 * The outcome is only published while the coroutine is still active. Failures are
 * reported through [logger] rather than propagated to the caller, with one exception:
 * cancellation is rethrown so that the calling coroutine is cancelled normally.
 *
 * @param bitmap image handed to [runSegmentationAndReturn]
 * @param rotationDegrees rotation value handed to [runSegmentationAndReturn]
 */
suspend fun runSegmentationAndEmit(bitmap: Bitmap, rotationDegrees: Int) {
    try {
        withContext(Dispatchers.IO) {
            val segmentationResult = runSegmentationAndReturn(bitmap, rotationDegrees)
            // Do not publish a result computed for a coroutine that was cancelled meanwhile.
            if (isActive) {
                _segmentation.value = segmentationResult
            }
        }
    } catch (e: kotlin.coroutines.cancellation.CancellationException) {
        // Cancellation is part of normal coroutine control flow, not an error:
        // it must reach the caller instead of being logged and swallowed.
        throw e
    } catch (e: Exception) {
        logger.e(TAG, "Error occurred in image segmentation", e)
    }
}
private fun segment(interpreter: Interpreter, tensorImage: TensorImage): Segmentation {
val (_, h, w, _) = interpreter.getOutputTensor(0).shape()
val outputBuffer = ByteBuffer.allocateDirect(4 * h * w)

View File

@@ -26,7 +26,6 @@ import kotlinx.coroutines.flow.MutableStateFlow
import kotlinx.coroutines.flow.StateFlow
import kotlinx.coroutines.flow.asSharedFlow
import kotlinx.coroutines.flow.asStateFlow
import kotlinx.coroutines.flow.filterNotNull
import kotlinx.coroutines.launch
import kotlinx.coroutines.withContext
import org.fairscan.app.AppContainer
@@ -70,36 +69,6 @@ class CameraViewModel(appContainer: AppContainer): ViewModel() {
init {
    // Initialize the segmentation service, then translate every non-null segmentation
    // result it publishes into a new live-analysis state.
    viewModelScope.launch {
        imageSegmentationService.initialize()
        imageSegmentationService.segmentation
            .filterNotNull()
            .collect { result ->
                // Deferred: the binary mask is only built (and rotated when the result
                // carries a non-zero rotation) once the provider is invoked.
                val binaryMaskProvider: () -> Bitmap = {
                    val unrotatedMask: Bitmap = result.segmentation.toBinaryMask()
                    if (result.rotationDegrees == 0) {
                        unrotatedMask
                    } else {
                        rotateBitmap(unrotatedMask, result.rotationDegrees.toFloat())
                    }
                }

                // Detect the quad on the segmentation output, then rotate it by the
                // number of quarter turns derived from the result's rotation.
                val detectedQuad = detectDocumentQuad(
                    result.segmentation,
                    result.originalSize,
                    isLiveAnalysis = true
                )
                val rawQuad = detectedQuad?.rotate90(
                    result.rotationDegrees / 90,
                    result.segmentation.width,
                    result.segmentation.height
                )

                // The stabilizer is fed every detection, including null ones.
                val stableQuad = quadStabilizer.update(rawQuad)

                val liveState = LiveAnalysisState(
                    inferenceTime = result.inferenceTime,
                    binaryMaskProvider = binaryMaskProvider,
                    maskSize = result.segmentation.maskSize(),
                    documentQuad = rawQuad,
                    stableQuad = stableQuad,
                )
                _liveAnalysisState.value = liveState
            }
    }
}
@@ -131,10 +100,43 @@ class CameraViewModel(appContainer: AppContainer): ViewModel() {
}
viewModelScope.launch {
imageSegmentationService.runSegmentationAndEmit(
imageProxy.toBitmap(),
imageProxy.imageInfo.rotationDegrees,
)
val result = withContext(Dispatchers.IO) {
imageSegmentationService.runSegmentationAndReturn(
imageProxy.toBitmap(),
imageProxy.imageInfo.rotationDegrees,
)
}
result?.let {
val rawQuad = withContext(Dispatchers.Default) {
detectDocumentQuad(
result.segmentation,
result.originalSize,
isLiveAnalysis = true
)?.rotate90(
result.rotationDegrees / 90,
result.segmentation.width,
result.segmentation.height
)
}
val binaryMaskProvider = { ->
var binaryMask: Bitmap = result.segmentation.toBinaryMask()
if (result.rotationDegrees != 0) {
binaryMask =
rotateBitmap(binaryMask, result.rotationDegrees.toFloat())
}
binaryMask
}
val stableQuad = quadStabilizer.update(rawQuad)
_liveAnalysisState.value = LiveAnalysisState(
inferenceTime = result.inferenceTime,
binaryMaskProvider = binaryMaskProvider,
maskSize = result.segmentation.maskSize(),
documentQuad = rawQuad,
stableQuad = stableQuad,
)
}
imageProxy.close()
}
}