diff --git a/app/src/main/assets/IMG_20250329_132414_900.jpg b/app/src/main/assets/IMG_20250329_132414_900.jpg new file mode 100644 index 0000000..a41867f Binary files /dev/null and b/app/src/main/assets/IMG_20250329_132414_900.jpg differ diff --git a/app/src/main/java/org/mydomain/myscan/ImageSegmentationService.kt b/app/src/main/java/org/mydomain/myscan/ImageSegmentationService.kt index 653ac23..464b042 100644 --- a/app/src/main/java/org/mydomain/myscan/ImageSegmentationService.kt +++ b/app/src/main/java/org/mydomain/myscan/ImageSegmentationService.kt @@ -4,8 +4,10 @@ import android.content.Context import android.graphics.Bitmap import android.graphics.Bitmap.createBitmap import android.graphics.Color.argb +import android.graphics.Matrix import android.os.SystemClock import android.util.Log +import androidx.core.graphics.scale import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.flow.MutableSharedFlow import kotlinx.coroutines.flow.MutableStateFlow @@ -16,13 +18,9 @@ import kotlinx.coroutines.isActive import kotlinx.coroutines.withContext import org.tensorflow.lite.Interpreter import org.tensorflow.lite.support.common.FileUtil -import org.tensorflow.lite.support.common.ops.NormalizeOp import org.tensorflow.lite.support.image.ColorSpaceType -import org.tensorflow.lite.support.image.ImageProcessor import org.tensorflow.lite.support.image.ImageProperties import org.tensorflow.lite.support.image.TensorImage -import org.tensorflow.lite.support.image.ops.ResizeOp -import org.tensorflow.lite.support.image.ops.Rot90Op import java.nio.ByteBuffer import java.nio.FloatBuffer @@ -44,7 +42,7 @@ class ImageSegmentationService(private val context: Context) { suspend fun initialize() { interpreter = try { - val litertBuffer = FileUtil.loadMappedFile(context, "deeplab_v3.tflite") + val litertBuffer = FileUtil.loadMappedFile(context, "mydeeplabv3.tflite") Log.i(TAG, "Loaded LiteRT model") val options = Interpreter.Options().apply { numThreads = 2 @@ -63,21 +61,23 @@ class ImageSegmentationService(private val context: Context) { if (interpreter == null) return@withContext val startTime = SystemClock.uptimeMillis() - val rotation = -rotationDegrees / 90 - val (_, h, w, _) = interpreter?.getOutputTensor(0)?.shape() ?: return@withContext - val imageProcessor = - ImageProcessor - .Builder() - .add(ResizeOp(h, w, ResizeOp.ResizeMethod.BILINEAR)) - .add(Rot90Op(rotation)) - .add(NormalizeOp(127.5f, 127.5f)) - .build() + val (_, _, h, w) = interpreter?.getInputTensor(0)?.shape() ?: return@withContext + val dataType = interpreter?.getInputTensor(0)?.dataType() + Log.i(TAG, "segment, input shape: ${interpreter!!.getInputTensor(0).shape().asList()} data type=${dataType}") + + // Preprocess manually into CHW float buffer + val inputBuffer = bitmapToCHWFloatBuffer(bitmap, width = w, height = h, rotationDegrees) + + val (_, cOut, hOut, wOut) = interpreter!!.getOutputTensor(0).shape() + val outputBuffer = FloatBuffer.allocate(cOut * hOut * wOut) + + // Run inference + outputBuffer.rewind() + interpreter?.run(inputBuffer, outputBuffer) - // Preprocess the image and convert it into a TensorImage for segmentation. - val tensorImage = imageProcessor.process(TensorImage.fromBitmap(bitmap)) - val segmentResult = segment(tensorImage) val inferenceTime = SystemClock.uptimeMillis() - startTime if (isActive) { + val segmentResult = processOutputBuffer(outputBuffer, wOut, hOut, cOut) _segmentation.value = SegmentationResult(segmentResult, inferenceTime) } } @@ -87,14 +87,46 @@ class ImageSegmentationService(private val context: Context) { } } + fun bitmapToCHWFloatBuffer(bitmap: Bitmap, width: Int, height: Int, rotationDegrees: Int): FloatBuffer { + val rotatedBitmap = if (rotationDegrees != 0) { + val matrix = Matrix().apply { postRotate(rotationDegrees.toFloat()) } + createBitmap(bitmap, 0, 0, bitmap.width, bitmap.height, matrix, true) + } else { + bitmap + } - private fun segment(tensorImage: TensorImage): Segmentation { - val (_, h, w, c) = interpreter!!.getOutputTensor(0).shape() - val outputBuffer = FloatBuffer.allocate(h * w * c) + val resized = rotatedBitmap.scale(width, height) + val buffer = FloatBuffer.allocate(1 * 3 * height * width) + buffer.rewind() - outputBuffer.rewind() - interpreter?.run(tensorImage.tensorBuffer.buffer, outputBuffer) + val mean = floatArrayOf(0.4611f, 0.4359f, 0.3905f) + val std = floatArrayOf(0.2193f, 0.2150f, 0.2109f) + val pixels = IntArray(width * height) + resized.getPixels(pixels, 0, width, 0, 0, width, height) + + // Fill buffer in CHW order + for (c in 0..2) { + for (i in 0 until height) { + for (j in 0 until width) { + val pixel = pixels[i * width + j] + val value = when (c) { + 0 -> (pixel shr 16 and 0xFF) // R + 1 -> (pixel shr 8 and 0xFF) // G + 2 -> (pixel and 0xFF) // B + else -> 0 + } + val normalized = (value / 255f - mean[c]) / std[c] + buffer.put(normalized) + } + } + } + + buffer.rewind() + return buffer + } + + private fun processOutputBuffer(outputBuffer: FloatBuffer, w: Int, h: Int, c: Int): Segmentation { outputBuffer.rewind() val inferenceData = InferenceData(width = w, height = h, channels = c, buffer = outputBuffer) @@ -116,15 +148,15 @@ class ImageSegmentationService(private val context: Context) { val mask = ByteBuffer.allocateDirect(inferenceData.width * inferenceData.height) for (i in 0 until inferenceData.height) { for (j in 0 until inferenceData.width) { - val offset = inferenceData.channels * (i * inferenceData.width + j) - var maxIndex = 0 - var maxValue = inferenceData.buffer.get(offset) + var maxValue = inferenceData.buffer.get(i * inferenceData.width + j) - for (index in 1 until inferenceData.channels) { - if (inferenceData.buffer.get(offset + index) > maxValue) { - maxValue = inferenceData.buffer.get(offset + index) - maxIndex = index + for (c in 1 until inferenceData.channels) { + val value = inferenceData.buffer.get( + c * inferenceData.height * inferenceData.width + i * inferenceData.width + j) + if (value > maxValue) { + maxValue = value + maxIndex = c } } @@ -164,4 +196,4 @@ class ImageSegmentationService(private val context: Context) { val channels: Int, val buffer: FloatBuffer, ) -} \ No newline at end of file +} diff --git a/app/src/main/java/org/mydomain/myscan/MainViewModel.kt b/app/src/main/java/org/mydomain/myscan/MainViewModel.kt index ce8bbfd..1caec79 100644 --- a/app/src/main/java/org/mydomain/myscan/MainViewModel.kt +++ b/app/src/main/java/org/mydomain/myscan/MainViewModel.kt @@ -1,6 +1,7 @@ package org.mydomain.myscan import android.content.Context +import android.graphics.Bitmap import android.util.Log import androidx.camera.core.ImageProxy import androidx.lifecycle.ViewModel @@ -34,7 +35,7 @@ class MainViewModel(private val imageSegmentationService: ImageSegmentationServi .filterNotNull() .map { UiState( - "Found ${numberOfObjectsDetected(it.segmentation)} objects!", + "Inference done", it.inferenceTime, it.segmentation.toBitmap()) } @@ -45,18 +46,7 @@ class MainViewModel(private val imageSegmentationService: ImageSegmentationServi } } - fun numberOfObjectsDetected(segmentation: ImageSegmentationService.Segmentation) : Int { - val tensor = segmentation.mask; - val buffer = tensor.buffer - val uniqueValues = HashSet() - for (i in 0..tensor.width * tensor.height - 1) { - uniqueValues.add(buffer[i].toInt()) - } - return uniqueValues.size - 1; - } - fun segment(imageProxy: ImageProxy) { - Log.d("MyScan", "MainViewModel.Calling segment") viewModelScope.launch { imageSegmentationService.runSegmentation( imageProxy.toBitmap(),