From 0ea4132d371f29158d4f951a1d05b4ffac3b1f3d Mon Sep 17 00:00:00 2001 From: pynicolas <6371790+pynicolas@users.noreply.github.com> Date: Sat, 21 Jun 2025 13:01:05 +0200 Subject: [PATCH] Make gradle download the TFLite segmentation model (#8) --- app/build.gradle.kts | 3 ++ app/download-tflite.gradle.kts | 38 +++++++++++++++++++ .../org/mydomain/myscan/ImageSegmentation.kt | 2 +- 3 files changed, 42 insertions(+), 1 deletion(-) create mode 100644 app/download-tflite.gradle.kts diff --git a/app/build.gradle.kts b/app/build.gradle.kts index 707e133..c91c2e1 100644 --- a/app/build.gradle.kts +++ b/app/build.gradle.kts @@ -7,6 +7,7 @@ plugins { android { namespace = "org.mydomain.myscan" compileSdk = 35 + sourceSets["main"].assets.srcDir(layout.buildDirectory.dir("generated/assets")) defaultConfig { applicationId = "org.mydomain.myscan" @@ -43,6 +44,8 @@ android { } } +apply(from = "download-tflite.gradle.kts") + dependencies { implementation(libs.androidx.core.ktx) diff --git a/app/download-tflite.gradle.kts b/app/download-tflite.gradle.kts new file mode 100644 index 0000000..399b94c --- /dev/null +++ b/app/download-tflite.gradle.kts @@ -0,0 +1,38 @@ +import java.net.URL +import org.gradle.api.tasks.Copy + +val modelVersion = "v0.2" +val modelFileName = "document-segmentation-model.tflite" +val modelUrl = "https://github.com/pynicolas/document-segmentation-model/releases/download/$modelVersion/$modelFileName" + +val downloadedModelPath = layout.buildDirectory.file("downloads/$modelFileName") +val generatedAssetsDir = layout.buildDirectory.dir("generated/assets") + +val downloadTFLiteModel = tasks.register("downloadTFLiteModel") { + val outputFile = downloadedModelPath.get().asFile + outputs.file(outputFile) + + doLast { + if (!outputFile.exists()) { + println("Downloading $modelFileName from $modelUrl") + outputFile.parentFile.mkdirs() + URL(modelUrl).openStream().use { input -> + outputFile.outputStream().use { output -> + input.copyTo(output) + } + } + } else { + println("Model already downloaded: ${outputFile.absolutePath}") + } + } +} + +val copyTFLiteToAssets = tasks.register("copyTFLiteToAssets") { + dependsOn(downloadTFLiteModel) + from(downloadedModelPath) + into(generatedAssetsDir) +} + +tasks.named("preBuild") { + dependsOn(copyTFLiteToAssets) +} diff --git a/app/src/main/java/org/mydomain/myscan/ImageSegmentation.kt b/app/src/main/java/org/mydomain/myscan/ImageSegmentation.kt index fb3203a..db4d361 100644 --- a/app/src/main/java/org/mydomain/myscan/ImageSegmentation.kt +++ b/app/src/main/java/org/mydomain/myscan/ImageSegmentation.kt @@ -53,7 +53,7 @@ class ImageSegmentationService(private val context: Context) { fun initialize() { interpreter = try { - val litertBuffer = FileUtil.loadMappedFile(context, "timm_efficientnet_lite0_quantized.tflite") + val litertBuffer = FileUtil.loadMappedFile(context, "document-segmentation-model.tflite") Log.i(TAG, "Loaded LiteRT model") val options = Interpreter.Options().apply { numThreads = 2