Use LiteRT to analyze an image: display a trivial message as a result
This commit is contained in:
1
.gitignore
vendored
1
.gitignore
vendored
@@ -13,3 +13,4 @@
|
||||
.externalNativeBuild
|
||||
.cxx
|
||||
local.properties
|
||||
*.tflite
|
||||
|
||||
3
.idea/deploymentTargetSelector.xml
generated
3
.idea/deploymentTargetSelector.xml
generated
@@ -13,6 +13,9 @@
|
||||
</DropdownSelection>
|
||||
<DialogSelection />
|
||||
</SelectionState>
|
||||
<SelectionState runConfigName="MainActivity">
|
||||
<option name="selectionMode" value="DROPDOWN" />
|
||||
</SelectionState>
|
||||
</selectionStates>
|
||||
</component>
|
||||
</project>
|
||||
@@ -43,6 +43,8 @@ dependencies {
|
||||
|
||||
implementation(libs.androidx.core.ktx)
|
||||
implementation(libs.androidx.lifecycle.runtime.ktx)
|
||||
implementation(libs.androidx.lifecycle.runtime.compose)
|
||||
implementation(libs.androidx.lifecycle.viewmodel.compose)
|
||||
implementation(libs.androidx.activity.compose)
|
||||
implementation(platform(libs.androidx.compose.bom))
|
||||
implementation(libs.androidx.ui)
|
||||
@@ -53,6 +55,9 @@ dependencies {
|
||||
implementation(libs.androidx.camera.camera2)
|
||||
implementation(libs.androidx.camera.lifecycle)
|
||||
implementation(libs.androidx.camera.view)
|
||||
implementation(libs.litert)
|
||||
implementation(libs.litert.support)
|
||||
implementation(libs.litert.metadata)
|
||||
|
||||
testImplementation(libs.junit)
|
||||
androidTestImplementation(libs.androidx.junit)
|
||||
|
||||
@@ -0,0 +1,207 @@
|
||||
package org.mydomain.myscan
|
||||
|
||||
import android.content.Context
|
||||
import android.graphics.Bitmap
|
||||
import android.graphics.Color
|
||||
import android.os.SystemClock
|
||||
import android.util.Log
|
||||
import kotlinx.coroutines.Dispatchers
|
||||
import kotlinx.coroutines.channels.BufferOverflow
|
||||
import kotlinx.coroutines.flow.MutableSharedFlow
|
||||
import kotlinx.coroutines.flow.MutableStateFlow
|
||||
import kotlinx.coroutines.flow.SharedFlow
|
||||
import kotlinx.coroutines.flow.StateFlow
|
||||
import kotlinx.coroutines.flow.asStateFlow
|
||||
import kotlinx.coroutines.isActive
|
||||
import kotlinx.coroutines.withContext
|
||||
import org.tensorflow.lite.Interpreter
|
||||
import org.tensorflow.lite.support.common.FileUtil
|
||||
import org.tensorflow.lite.support.common.ops.NormalizeOp
|
||||
import org.tensorflow.lite.support.image.ColorSpaceType
|
||||
import org.tensorflow.lite.support.image.ImageProcessor
|
||||
import org.tensorflow.lite.support.image.ImageProperties
|
||||
import org.tensorflow.lite.support.image.TensorImage
|
||||
import org.tensorflow.lite.support.image.ops.ResizeOp
|
||||
import org.tensorflow.lite.support.image.ops.Rot90Op
|
||||
import java.nio.ByteBuffer
|
||||
import java.nio.FloatBuffer
|
||||
import java.util.Random
|
||||
|
||||
// TODO Review and remove unneeded code
|
||||
/**
 * Runs a LiteRT semantic-segmentation model (`deeplab_v3.tflite`) on bitmaps and publishes
 * the resulting per-pixel class mask.
 *
 * Call [initialize] once before [runSegmentation]. Successful results are published on
 * [segmentation]; failures from either call are logged and emitted on [error].
 */
class ImageSegmentationService(private val context: Context) {

    private val _segmentation = MutableStateFlow<SegmentationResult?>(null)

    /** Most recent segmentation result, or `null` until the first successful inference. */
    val segmentation: StateFlow<SegmentationResult?> = _segmentation.asStateFlow()

    private val _error = MutableSharedFlow<Throwable?>()

    /** Exceptions raised while loading the model or while running inference. */
    val error: SharedFlow<Throwable?>
        get() = _error

    private var interpreter: Interpreter? = null

    private val coloredLabels: List<ColoredLabel> = coloredLabels()

    /**
     * Loads the model file and creates the interpreter.
     *
     * The blocking file mapping and interpreter construction run on [Dispatchers.IO], so this
     * function may be called from the main thread. On failure the exception is logged, emitted
     * on [error], and the service stays uninitialized ([runSegmentation] then does nothing).
     */
    suspend fun initialize() {
        interpreter = try {
            withContext(Dispatchers.IO) {
                val modelBuffer = FileUtil.loadMappedFile(context, MODEL_FILE)
                Log.i(TAG, "Loaded LiteRT model")
                val options = Interpreter.Options().apply {
                    numThreads = 2
                }
                Interpreter(modelBuffer, options)
            }
        } catch (e: kotlinx.coroutines.CancellationException) {
            // Cancellation is not a failure: let it propagate to the caller's scope.
            throw e
        } catch (e: Exception) {
            Log.e(TAG, "Failed to load LiteRT model: ${e.message}", e)
            _error.emit(e)
            null
        }
    }

    /**
     * Segments [bitmap] and publishes the outcome on [segmentation].
     *
     * Does nothing when [initialize] has not succeeded. Any inference failure is logged and
     * emitted on [error] instead of being thrown; cancellation is rethrown.
     *
     * @param bitmap frame to analyze.
     * @param rotationDegrees rotation reported for the frame, expected to be a multiple of 90.
     */
    suspend fun runSegmentation(bitmap: Bitmap, rotationDegrees: Int) {
        // Snapshot once so the null check and every later use see the same instance.
        val model = interpreter ?: return
        try {
            withContext(Dispatchers.IO) {
                val startTime = SystemClock.uptimeMillis()

                // Number of quarter turns handed to Rot90Op; negated to undo the frame rotation.
                val rotation = -rotationDegrees / 90
                val (_, h, w, _) = model.getOutputTensor(0).shape()
                val imageProcessor =
                    ImageProcessor
                        .Builder()
                        .add(ResizeOp(h, w, ResizeOp.ResizeMethod.BILINEAR))
                        .add(Rot90Op(rotation))
                        .add(NormalizeOp(127.5f, 127.5f))
                        .build()

                // Preprocess the image and convert it into a TensorImage for segmentation.
                val tensorImage = imageProcessor.process(TensorImage.fromBitmap(bitmap))
                val segmentResult = segment(model, tensorImage)
                val inferenceTime = SystemClock.uptimeMillis() - startTime
                // Do not publish a result for a request that was cancelled meanwhile.
                if (isActive) {
                    _segmentation.value = SegmentationResult(segmentResult, inferenceTime)
                }
            }
        } catch (e: kotlinx.coroutines.CancellationException) {
            // Rethrow so structured concurrency keeps working; this is not an inference error.
            throw e
        } catch (e: Exception) {
            Log.e(TAG, "Image segmentation error occurred: ${e.message}", e)
            _error.emit(e)
        }
    }

    /**
     * Runs [model] on [tensorImage] and wraps the arg-max class mask in a [Segmentation].
     *
     * The output tensor shape is read as (batch, height, width, channels).
     */
    private fun segment(model: Interpreter, tensorImage: TensorImage): Segmentation {
        val (_, h, w, c) = model.getOutputTensor(0).shape()
        val outputBuffer = FloatBuffer.allocate(h * w * c)

        model.run(tensorImage.tensorBuffer.buffer, outputBuffer)

        // Reset the position so processImage reads from the start of the scores.
        outputBuffer.rewind()
        val inferenceData =
            InferenceData(width = w, height = h, channels = c, buffer = outputBuffer)
        val mask = processImage(inferenceData)

        val imageProperties =
            ImageProperties
                .builder()
                .setWidth(inferenceData.width)
                .setHeight(inferenceData.height)
                .setColorSpaceType(ColorSpaceType.GRAYSCALE)
                .build()
        val maskImage = TensorImage()
        maskImage.load(mask, imageProperties)
        return Segmentation(
            listOf(maskImage), coloredLabels
        )
    }

    /**
     * Builds a one-byte-per-pixel mask holding, for each pixel, the index of the channel with
     * the highest score (the first such channel wins on ties).
     *
     * Scores are read from [InferenceData.buffer] with channels innermost:
     * `channels * (row * width + column) + channel`.
     */
    private fun processImage(inferenceData: InferenceData): ByteBuffer {
        val width = inferenceData.width
        val height = inferenceData.height
        val channels = inferenceData.channels
        val scores = inferenceData.buffer

        val mask = ByteBuffer.allocateDirect(width * height)
        for (row in 0 until height) {
            for (column in 0 until width) {
                val pixel = row * width + column
                val offset = channels * pixel

                var maxIndex = 0
                var maxValue = scores.get(offset)
                for (channel in 1 until channels) {
                    val score = scores.get(offset + channel)
                    if (score > maxValue) {
                        maxValue = score
                        maxIndex = channel
                    }
                }

                mask.put(pixel, maxIndex.toByte())
            }
        }

        return mask
    }

    /**
     * Pairs every class label with a display color.
     *
     * The first label (background) is black; the others get hues stepped by the golden-ratio
     * conjugate from a random starting hue, so colors differ between service instances.
     */
    private fun coloredLabels(): List<ColoredLabel> {
        val labels = listOf(
            "background",
            "aeroplane",
            "bicycle",
            "bird",
            "boat",
            "bottle",
            "bus",
            "car",
            "cat",
            "chair",
            "cow",
            "dining table",
            "dog",
            "horse",
            "motorbike",
            "person",
            "potted plant",
            "sheep",
            "sofa",
            "train",
            "tv",
            "------"
        )

        val goldenRatioConjugate = 0.618033988749895
        var hue = Random().nextDouble()

        return labels.mapIndexed { idx, label ->
            if (idx == 0) {
                // The first label keeps a fixed black color and does not advance the hue.
                ColoredLabel(label, "", Color.BLACK)
            } else {
                hue = (hue + goldenRatioConjugate) % 1.0
                // Adjust saturation & lightness as needed
                val color = Color.HSVToColor(floatArrayOf(hue.toFloat() * 360, 0.7f, 0.8f))
                ColoredLabel(label, "", color)
            }
        }
    }

    /** Mask images produced by one inference together with the label/color table. */
    data class Segmentation(
        val masks: List<TensorImage>,
        val coloredLabels: List<ColoredLabel>,
    )

    /** A class label, its display name and its ARGB color. */
    data class ColoredLabel(val label: String, val displayName: String, val argb: Int)

    /** A [Segmentation] plus the time in milliseconds spent producing it. */
    data class SegmentationResult(
        val segmentation: Segmentation,
        val inferenceTime: Long
    )

    /** Raw model output scores and the dimensions needed to index them. */
    data class InferenceData(
        val width: Int,
        val height: Int,
        val channels: Int,
        val buffer: FloatBuffer,
    )

    companion object {
        private const val TAG = "ImageSegmentation"
        private const val MODEL_FILE = "deeplab_v3.tflite"
    }
}
|
||||
@@ -5,37 +5,41 @@ import android.util.Log
|
||||
import androidx.activity.ComponentActivity
|
||||
import androidx.activity.compose.setContent
|
||||
import androidx.activity.enableEdgeToEdge
|
||||
import androidx.activity.viewModels
|
||||
import androidx.compose.foundation.background
|
||||
import androidx.compose.foundation.layout.Box
|
||||
import androidx.compose.foundation.layout.Column
|
||||
import androidx.compose.foundation.layout.fillMaxWidth
|
||||
import androidx.compose.foundation.layout.padding
|
||||
import androidx.compose.material3.Scaffold
|
||||
import androidx.compose.material3.Text
|
||||
import androidx.compose.runtime.Composable
|
||||
import androidx.compose.runtime.getValue
|
||||
import androidx.compose.ui.Modifier
|
||||
import androidx.compose.ui.graphics.Color
|
||||
import androidx.compose.ui.tooling.preview.Preview
|
||||
import androidx.compose.ui.unit.dp
|
||||
import androidx.compose.ui.unit.sp
|
||||
import androidx.lifecycle.compose.collectAsStateWithLifecycle
|
||||
import org.mydomain.myscan.ui.theme.MyScanTheme
|
||||
import org.mydomain.myscan.view.CameraScreen
|
||||
import java.util.Date
|
||||
|
||||
class MainActivity : ComponentActivity() {
|
||||
|
||||
companion object {
|
||||
private const val TAG = "MyScan"
|
||||
}
|
||||
|
||||
override fun onCreate(savedInstanceState: Bundle?) {
|
||||
super.onCreate(savedInstanceState)
|
||||
val viewModel: MainViewModel by viewModels { MainViewModel.getFactory(this) }
|
||||
enableEdgeToEdge()
|
||||
setContent {
|
||||
val uiState by viewModel.uiState.collectAsStateWithLifecycle()
|
||||
Log.d("MyScan", "!!"+uiState.toString())
|
||||
MyScanTheme {
|
||||
Scaffold(/*modifier = Modifier.fillMaxSize()*/) { innerPadding ->
|
||||
Scaffold { innerPadding ->
|
||||
Column {
|
||||
Greeting(modifier = Modifier.padding(innerPadding))
|
||||
Box(/*modifier = Modifier.width(300.dp)*/) {
|
||||
CameraScreen(onImageAnalyzed = { image ->
|
||||
Log.d(TAG, Date().toString())
|
||||
image.close()
|
||||
} )
|
||||
MyMessageBox(uiState.detectionMessage, uiState.inferenceTime)
|
||||
Box {
|
||||
CameraScreen(onImageAnalyzed = { image -> viewModel.segment(image) } )
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -44,6 +48,20 @@ class MainActivity : ComponentActivity() {
|
||||
}
|
||||
}
|
||||
|
||||
/**
 * Banner showing the detection message followed by the inference duration.
 *
 * Rendered as black 20sp text on a yellow, full-width background with 16dp padding.
 *
 * @param msg detection message; `null` is shown as an empty string.
 * @param inferenceTime duration appended to the message with an "ms" suffix.
 */
@Composable
fun MyMessageBox(msg: String?, inferenceTime: Long) {
    Log.d("MyScan", "MyMessageBox recompose: $msg")

    val caption = "${msg.orEmpty()} inferred in ${inferenceTime}ms"
    val bannerModifier = Modifier
        .padding(16.dp)
        .background(Color.Yellow)
        .fillMaxWidth()

    Text(
        text = caption,
        modifier = bannerModifier,
        color = Color.Black,
        fontSize = 20.sp,
    )
}
|
||||
|
||||
@Composable
|
||||
fun Greeting(modifier: Modifier = Modifier) {
|
||||
Text(
|
||||
@@ -54,8 +72,8 @@ fun Greeting(modifier: Modifier = Modifier) {
|
||||
|
||||
@Preview(showBackground = true)
|
||||
@Composable
|
||||
fun GreetingPreview() {
|
||||
fun MyMessageBoxPreview() {
|
||||
MyScanTheme {
|
||||
Greeting()
|
||||
MyMessageBox("Found 2 objects!", 42)
|
||||
}
|
||||
}
|
||||
56
app/src/main/java/org/mydomain/myscan/MainViewModel.kt
Normal file
56
app/src/main/java/org/mydomain/myscan/MainViewModel.kt
Normal file
@@ -0,0 +1,56 @@
|
||||
package org.mydomain.myscan
|
||||
|
||||
import android.content.Context
|
||||
import android.util.Log
|
||||
import androidx.camera.core.ImageProxy
|
||||
import androidx.lifecycle.ViewModel
|
||||
import androidx.lifecycle.ViewModelProvider
|
||||
import androidx.lifecycle.viewModelScope
|
||||
import androidx.lifecycle.viewmodel.CreationExtras
|
||||
import kotlinx.coroutines.flow.MutableStateFlow
|
||||
import kotlinx.coroutines.flow.StateFlow
|
||||
import kotlinx.coroutines.flow.asStateFlow
|
||||
import kotlinx.coroutines.flow.filterNotNull
|
||||
import kotlinx.coroutines.flow.map
|
||||
import kotlinx.coroutines.flow.update
|
||||
import kotlinx.coroutines.launch
|
||||
|
||||
/**
 * Feeds camera frames to [ImageSegmentationService] and exposes the outcome as [UiState].
 *
 * On creation it initializes the service and then mirrors every non-null segmentation
 * result into [uiState].
 */
class MainViewModel(private val imageSegmentationService: ImageSegmentationService) : ViewModel() {

    private val _uiState = MutableStateFlow(UiState("just started"))

    /** Current screen state; starts with the "just started" placeholder message. */
    val uiState: StateFlow<UiState> = _uiState.asStateFlow()

    init {
        viewModelScope.launch {
            imageSegmentationService.initialize()
            imageSegmentationService.segmentation
                .filterNotNull()
                // NOTE(review): masks.size counts mask images, not detected objects — confirm
                // that this is the intended figure for the message.
                .map { UiState("Found ${it.segmentation.masks.size} objects!", it.inferenceTime) }
                .collect { newState ->
                    Log.d(TAG, "New UIstate ${newState}")
                    _uiState.value = newState
                }
        }
    }

    /**
     * Runs segmentation on [imageProxy] and closes it afterwards.
     *
     * The frame is closed in a `finally` block so it is released even when the bitmap
     * conversion or the segmentation throws, or when the coroutine is cancelled.
     */
    fun segment(imageProxy: ImageProxy) {
        Log.d(TAG, "MainViewModel.Calling segment")
        viewModelScope.launch {
            try {
                imageSegmentationService.runSegmentation(
                    imageProxy.toBitmap(),
                    imageProxy.imageInfo.rotationDegrees,
                )
            } finally {
                imageProxy.close()
            }
        }
    }

    companion object {
        private const val TAG = "MyScan"

        /**
         * Creates a factory that builds [MainViewModel] with its own [ImageSegmentationService].
         *
         * Only the application context is handed to the service, so the ViewModel does not
         * keep the calling Activity alive after it is destroyed.
         */
        fun getFactory(context: Context) = object : ViewModelProvider.Factory {
            @Suppress("UNCHECKED_CAST")
            override fun <T : ViewModel> create(modelClass: Class<T>, extras: CreationExtras): T {
                return MainViewModel(ImageSegmentationService(context.applicationContext)) as T
            }
        }
    }
}
|
||||
10
app/src/main/java/org/mydomain/myscan/UiState.kt
Normal file
10
app/src/main/java/org/mydomain/myscan/UiState.kt
Normal file
@@ -0,0 +1,10 @@
|
||||
package org.mydomain.myscan
|
||||
|
||||
import androidx.compose.runtime.Immutable
|
||||
|
||||
/**
 * Snapshot of what the main screen displays.
 *
 * All fields have defaults, so `UiState()` is a valid empty state.
 *
 * @property detectionMessage text describing the latest analysis result, if any.
 * @property inferenceTime duration of the latest inference; defaults to 0.
 * @property errorMessage text describing a failure, if any.
 */
@Immutable
data class UiState(
    val detectionMessage: String? = null,
    val inferenceTime: Long = 0L,
    val errorMessage: String? = null,
)
|
||||
@@ -9,6 +9,7 @@ lifecycleRuntimeKtx = "2.9.0"
|
||||
activityCompose = "1.10.1"
|
||||
composeBom = "2025.05.00"
|
||||
camerax = "1.4.2"
|
||||
litert = "1.2.0"
|
||||
|
||||
[libraries]
|
||||
androidx-core-ktx = { group = "androidx.core", name = "core-ktx", version.ref = "coreKtx" }
|
||||
@@ -16,6 +17,8 @@ junit = { group = "junit", name = "junit", version.ref = "junit" }
|
||||
androidx-junit = { group = "androidx.test.ext", name = "junit", version.ref = "junitVersion" }
|
||||
androidx-espresso-core = { group = "androidx.test.espresso", name = "espresso-core", version.ref = "espressoCore" }
|
||||
androidx-lifecycle-runtime-ktx = { group = "androidx.lifecycle", name = "lifecycle-runtime-ktx", version.ref = "lifecycleRuntimeKtx" }
|
||||
androidx-lifecycle-viewmodel-compose = { group = "androidx.lifecycle", name = "lifecycle-viewmodel-compose", version.ref = "lifecycleRuntimeKtx" }
|
||||
androidx-lifecycle-runtime-compose = { group = "androidx.lifecycle", name = "lifecycle-runtime-compose", version.ref = "lifecycleRuntimeKtx" }
|
||||
androidx-activity-compose = { group = "androidx.activity", name = "activity-compose", version.ref = "activityCompose" }
|
||||
androidx-compose-bom = { group = "androidx.compose", name = "compose-bom", version.ref = "composeBom" }
|
||||
androidx-ui = { group = "androidx.compose.ui", name = "ui" }
|
||||
@@ -30,6 +33,9 @@ androidx-camera-core = { group = "androidx.camera", name = "camera-core", versio
|
||||
androidx-camera-camera2 = { group = "androidx.camera", name = "camera-camera2", version.ref = "camerax" }
|
||||
androidx-camera-lifecycle = { group = "androidx.camera", name = "camera-lifecycle", version.ref = "camerax" }
|
||||
androidx-camera-view = { group = "androidx.camera", name = "camera-view", version.ref = "camerax" }
|
||||
litert = { group = "com.google.ai.edge.litert", name = "litert", version.ref = "litert" }
|
||||
litert-support = { group = "com.google.ai.edge.litert", name = "litert-support", version.ref = "litert" }
|
||||
litert-metadata = { group = "com.google.ai.edge.litert", name = "litert-metadata", version.ref = "litert" }
|
||||
|
||||
[plugins]
|
||||
android-application = { id = "com.android.application", version.ref = "agp" }
|
||||
|
||||
Reference in New Issue
Block a user