New segmentation model
This commit is contained in:
committed by
pynicolas
parent
e3ada11c8c
commit
0c5a219783
@@ -18,23 +18,24 @@ import android.content.Context
|
|||||||
import android.graphics.Bitmap
|
import android.graphics.Bitmap
|
||||||
import android.graphics.Bitmap.createBitmap
|
import android.graphics.Bitmap.createBitmap
|
||||||
import android.graphics.Color
|
import android.graphics.Color
|
||||||
import android.graphics.Matrix
|
|
||||||
import android.os.SystemClock
|
import android.os.SystemClock
|
||||||
import android.util.Log
|
import android.util.Log
|
||||||
import androidx.core.graphics.scale
|
|
||||||
import kotlinx.coroutines.Dispatchers
|
import kotlinx.coroutines.Dispatchers
|
||||||
import kotlinx.coroutines.flow.MutableStateFlow
|
import kotlinx.coroutines.flow.MutableStateFlow
|
||||||
import kotlinx.coroutines.flow.StateFlow
|
import kotlinx.coroutines.flow.StateFlow
|
||||||
import kotlinx.coroutines.flow.asStateFlow
|
import kotlinx.coroutines.flow.asStateFlow
|
||||||
import kotlinx.coroutines.isActive
|
import kotlinx.coroutines.isActive
|
||||||
import kotlinx.coroutines.withContext
|
import kotlinx.coroutines.withContext
|
||||||
|
import org.tensorflow.lite.DataType
|
||||||
import org.tensorflow.lite.Interpreter
|
import org.tensorflow.lite.Interpreter
|
||||||
import org.tensorflow.lite.support.common.FileUtil
|
import org.tensorflow.lite.support.common.FileUtil
|
||||||
import org.tensorflow.lite.support.image.ColorSpaceType
|
import org.tensorflow.lite.support.common.ops.NormalizeOp
|
||||||
import org.tensorflow.lite.support.image.ImageProperties
|
import org.tensorflow.lite.support.image.ImageProcessor
|
||||||
import org.tensorflow.lite.support.image.TensorImage
|
import org.tensorflow.lite.support.image.TensorImage
|
||||||
|
import org.tensorflow.lite.support.image.ops.ResizeOp
|
||||||
|
import org.tensorflow.lite.support.image.ops.Rot90Op
|
||||||
import java.nio.ByteBuffer
|
import java.nio.ByteBuffer
|
||||||
import java.nio.FloatBuffer
|
import java.nio.ByteOrder
|
||||||
|
|
||||||
class ImageSegmentationService(private val context: Context) {
|
class ImageSegmentationService(private val context: Context) {
|
||||||
|
|
||||||
@@ -49,7 +50,7 @@ class ImageSegmentationService(private val context: Context) {
|
|||||||
|
|
||||||
fun initialize() {
|
fun initialize() {
|
||||||
interpreter = try {
|
interpreter = try {
|
||||||
val litertBuffer = FileUtil.loadMappedFile(context, "mydeeplabv3.tflite")
|
val litertBuffer = FileUtil.loadMappedFile(context, "timm_efficientnet_lite0_quantized.tflite")
|
||||||
Log.i(TAG, "Loaded LiteRT model")
|
Log.i(TAG, "Loaded LiteRT model")
|
||||||
val options = Interpreter.Options().apply {
|
val options = Interpreter.Options().apply {
|
||||||
numThreads = 2
|
numThreads = 2
|
||||||
@@ -64,19 +65,21 @@ class ImageSegmentationService(private val context: Context) {
|
|||||||
private fun runSegmentation(interpreter: Interpreter, bitmap: Bitmap, rotationDegrees: Int): SegmentationResult {
|
private fun runSegmentation(interpreter: Interpreter, bitmap: Bitmap, rotationDegrees: Int): SegmentationResult {
|
||||||
val startTime = SystemClock.uptimeMillis()
|
val startTime = SystemClock.uptimeMillis()
|
||||||
|
|
||||||
val (_, _, h, w) = interpreter.getInputTensor(0).shape()
|
val rotation = -rotationDegrees / 90
|
||||||
// Preprocess manually into CHW float buffer
|
val (_, h, w, _) = interpreter.getOutputTensor(0).shape()
|
||||||
val inputBuffer = bitmapToCHWFloatBuffer(bitmap, width = w, height = h, rotationDegrees)
|
val imageProcessor =
|
||||||
|
ImageProcessor
|
||||||
val (_, cOut, hOut, wOut) = interpreter.getOutputTensor(0).shape()
|
.Builder()
|
||||||
val outputBuffer = FloatBuffer.allocate(cOut * hOut * wOut)
|
.add(ResizeOp(h, w, ResizeOp.ResizeMethod.BILINEAR))
|
||||||
|
.add(Rot90Op(rotation))
|
||||||
// Run inference
|
.add(NormalizeOp(127.5f, 127.5f)) // TODO check if it's correct
|
||||||
outputBuffer.rewind()
|
.build()
|
||||||
interpreter.run(inputBuffer, outputBuffer)
|
val tensorImage = TensorImage(DataType.FLOAT32)
|
||||||
|
tensorImage.load(bitmap)
|
||||||
|
val processedImage = imageProcessor.process(tensorImage)
|
||||||
|
val segmentResult = segment(interpreter, processedImage)
|
||||||
|
|
||||||
val inferenceTime = SystemClock.uptimeMillis() - startTime
|
val inferenceTime = SystemClock.uptimeMillis() - startTime
|
||||||
val segmentResult = processOutputBuffer(outputBuffer, wOut, hOut, cOut)
|
|
||||||
return SegmentationResult(segmentResult, inferenceTime)
|
return SegmentationResult(segmentResult, inferenceTime)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -101,111 +104,40 @@ class ImageSegmentationService(private val context: Context) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fun bitmapToCHWFloatBuffer(bitmap: Bitmap, width: Int, height: Int, rotationDegrees: Int): FloatBuffer {
|
private fun segment(interpreter: Interpreter, tensorImage: TensorImage): Segmentation {
|
||||||
val rotatedBitmap = if (rotationDegrees != 0) {
|
val (_, h, w, _) = interpreter.getOutputTensor(0).shape()
|
||||||
val matrix = Matrix().apply { postRotate(rotationDegrees.toFloat()) }
|
val outputBuffer = ByteBuffer.allocateDirect(4 * h * w)
|
||||||
createBitmap(bitmap, 0, 0, bitmap.width, bitmap.height, matrix, true)
|
outputBuffer.order(ByteOrder.nativeOrder())
|
||||||
} else {
|
outputBuffer.rewind()
|
||||||
bitmap
|
interpreter.run(tensorImage.tensorBuffer.buffer, outputBuffer)
|
||||||
}
|
outputBuffer.rewind()
|
||||||
|
val mask = generateMaskFromOutputBuffer(outputBuffer, w, h)
|
||||||
|
return Segmentation(mask)
|
||||||
|
}
|
||||||
|
|
||||||
val resized = rotatedBitmap.scale(width, height)
|
private fun generateMaskFromOutputBuffer(outputBuffer: ByteBuffer, width: Int, height: Int): Bitmap {
|
||||||
val buffer = FloatBuffer.allocate(1 * 3 * height * width)
|
outputBuffer.rewind()
|
||||||
buffer.rewind()
|
val floatArray = FloatArray(width * height)
|
||||||
|
outputBuffer.asFloatBuffer()[floatArray]
|
||||||
val mean = floatArrayOf(0.4611f, 0.4359f, 0.3905f)
|
|
||||||
val std = floatArrayOf(0.2193f, 0.2150f, 0.2109f)
|
|
||||||
|
|
||||||
val pixels = IntArray(width * height)
|
val pixels = IntArray(width * height)
|
||||||
resized.getPixels(pixels, 0, width, 0, 0, width, height)
|
for (i in floatArray.indices) {
|
||||||
|
val value = floatArray[i].coerceIn(0f, 1f)
|
||||||
// Fill buffer in CHW order
|
val gray = (value * 255).toInt()
|
||||||
for (c in 0..2) {
|
pixels[i] = Color.rgb(gray, gray, gray)
|
||||||
for (i in 0 until height) {
|
|
||||||
for (j in 0 until width) {
|
|
||||||
val pixel = pixels[i * width + j]
|
|
||||||
val value = when (c) {
|
|
||||||
0 -> (pixel shr 16 and 0xFF) // R
|
|
||||||
1 -> (pixel shr 8 and 0xFF) // G
|
|
||||||
2 -> (pixel and 0xFF) // B
|
|
||||||
else -> 0
|
|
||||||
}
|
|
||||||
val normalized = (value / 255f - mean[c]) / std[c]
|
|
||||||
buffer.put(normalized)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
buffer.rewind()
|
val bitmap = createBitmap(width, height, Bitmap.Config.ARGB_8888)
|
||||||
return buffer
|
bitmap.setPixels(pixels, 0, width, 0, 0, width, height)
|
||||||
|
return bitmap
|
||||||
}
|
}
|
||||||
|
|
||||||
private fun processOutputBuffer(outputBuffer: FloatBuffer, w: Int, h: Int, c: Int): Segmentation {
|
data class Segmentation(val mask: Bitmap) {
|
||||||
outputBuffer.rewind()
|
fun toBinaryMask(): Bitmap = mask
|
||||||
val inferenceData =
|
|
||||||
InferenceData(width = w, height = h, channels = c, buffer = outputBuffer)
|
|
||||||
val mask = generateMaskFromOutputBuffer(inferenceData)
|
|
||||||
|
|
||||||
val imageProperties =
|
|
||||||
ImageProperties
|
|
||||||
.builder()
|
|
||||||
.setWidth(inferenceData.width)
|
|
||||||
.setHeight(inferenceData.height)
|
|
||||||
.setColorSpaceType(ColorSpaceType.GRAYSCALE)
|
|
||||||
.build()
|
|
||||||
val maskImage = TensorImage()
|
|
||||||
maskImage.load(mask, imageProperties)
|
|
||||||
return Segmentation(maskImage)
|
|
||||||
}
|
|
||||||
|
|
||||||
private fun generateMaskFromOutputBuffer(inferenceData: InferenceData): ByteBuffer {
|
|
||||||
val width = inferenceData.width
|
|
||||||
val height = inferenceData.height
|
|
||||||
val mask = ByteBuffer.allocateDirect(width * height)
|
|
||||||
for (i in 0 until height) {
|
|
||||||
for (j in 0 until width) {
|
|
||||||
var maxIndex = 0
|
|
||||||
var maxValue = inferenceData.buffer[i * width + j]
|
|
||||||
|
|
||||||
for (c in 1 until inferenceData.channels) {
|
|
||||||
val value = inferenceData.buffer[c * height * width + i * width + j]
|
|
||||||
if (value > maxValue) {
|
|
||||||
maxValue = value
|
|
||||||
maxIndex = c
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
mask.put(i * width + j, maxIndex.toByte())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return mask
|
|
||||||
}
|
|
||||||
|
|
||||||
data class Segmentation(val mask: TensorImage) {
|
|
||||||
fun toBinaryMask(): Bitmap {
|
|
||||||
val width = mask.width
|
|
||||||
val height = mask.height
|
|
||||||
val pixels = IntArray(width * height)
|
|
||||||
for (i in 0 until height) {
|
|
||||||
for (j in 0 until width) {
|
|
||||||
val index = i * width + j
|
|
||||||
val classId = mask.buffer[index].toInt() and 0xFF // Unsigned byte
|
|
||||||
pixels[index] = if (classId == 0) Color.BLACK else Color.WHITE
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return createBitmap(pixels, width, height, Bitmap.Config.ARGB_8888)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
data class SegmentationResult(
|
data class SegmentationResult(
|
||||||
val segmentation: Segmentation,
|
val segmentation: Segmentation,
|
||||||
val inferenceTime: Long
|
val inferenceTime: Long
|
||||||
)
|
)
|
||||||
|
|
||||||
data class InferenceData(
|
|
||||||
val width: Int,
|
|
||||||
val height: Int,
|
|
||||||
val channels: Int,
|
|
||||||
val buffer: FloatBuffer,
|
|
||||||
)
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -60,6 +60,8 @@ class MainViewModel(
|
|||||||
private var _pageToValidate = MutableStateFlow<Bitmap?>(null)
|
private var _pageToValidate = MutableStateFlow<Bitmap?>(null)
|
||||||
val pageToValidate: StateFlow<Bitmap?> = _pageToValidate.asStateFlow()
|
val pageToValidate: StateFlow<Bitmap?> = _pageToValidate.asStateFlow()
|
||||||
|
|
||||||
|
var liveAnalysisEnabled = true
|
||||||
|
|
||||||
init {
|
init {
|
||||||
viewModelScope.launch {
|
viewModelScope.launch {
|
||||||
imageSegmentationService.initialize()
|
imageSegmentationService.initialize()
|
||||||
@@ -80,6 +82,11 @@ class MainViewModel(
|
|||||||
}
|
}
|
||||||
|
|
||||||
fun segment(imageProxy: ImageProxy) {
|
fun segment(imageProxy: ImageProxy) {
|
||||||
|
if (!liveAnalysisEnabled) {
|
||||||
|
imageProxy.close()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
viewModelScope.launch {
|
viewModelScope.launch {
|
||||||
imageSegmentationService.runSegmentationAndEmit(
|
imageSegmentationService.runSegmentationAndEmit(
|
||||||
imageProxy.toBitmap(),
|
imageProxy.toBitmap(),
|
||||||
|
|||||||
@@ -131,6 +131,7 @@ fun CameraScreen(
|
|||||||
pageCount = viewModel.pageCount(),
|
pageCount = viewModel.pageCount(),
|
||||||
liveAnalysisState,
|
liveAnalysisState,
|
||||||
onCapture = {
|
onCapture = {
|
||||||
|
viewModel.liveAnalysisEnabled = false
|
||||||
showPageDialog.value = true
|
showPageDialog.value = true
|
||||||
isProcessing.value = true
|
isProcessing.value = true
|
||||||
captureController.takePicture(
|
captureController.takePicture(
|
||||||
@@ -138,6 +139,7 @@ fun CameraScreen(
|
|||||||
if (imageProxy != null) {
|
if (imageProxy != null) {
|
||||||
viewModel.processCapturedImageThen(imageProxy) {
|
viewModel.processCapturedImageThen(imageProxy) {
|
||||||
isProcessing.value = false
|
isProcessing.value = false
|
||||||
|
viewModel.liveAnalysisEnabled = true
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
Log.e("MyScan", "Error during image capture")
|
Log.e("MyScan", "Error during image capture")
|
||||||
|
|||||||
Reference in New Issue
Block a user