Make gradle download the TFLite segmentation model (#8)

This commit is contained in:
pynicolas
2025-06-21 13:01:05 +02:00
committed by GitHub
parent 7056393f56
commit 0ea4132d37
3 changed files with 42 additions and 1 deletions

View File

@@ -7,6 +7,7 @@ plugins {
android { android {
namespace = "org.mydomain.myscan" namespace = "org.mydomain.myscan"
compileSdk = 35 compileSdk = 35
sourceSets["main"].assets.srcDir(layout.buildDirectory.dir("generated/assets"))
defaultConfig { defaultConfig {
applicationId = "org.mydomain.myscan" applicationId = "org.mydomain.myscan"
@@ -43,6 +44,8 @@ android {
} }
} }
apply(from = "download-tflite.gradle.kts")
dependencies { dependencies {
implementation(libs.androidx.core.ktx) implementation(libs.androidx.core.ktx)

View File

@@ -0,0 +1,38 @@
import java.net.URL
import org.gradle.api.tasks.Copy
val modelVersion = "v0.2"
val modelFileName = "document-segmentation-model.tflite"
val modelUrl = "https://github.com/pynicolas/document-segmentation-model/releases/download/$modelVersion/$modelFileName"
val downloadedModelPath = layout.buildDirectory.file("downloads/$modelFileName")
val generatedAssetsDir = layout.buildDirectory.dir("generated/assets")
val downloadTFLiteModel = tasks.register("downloadTFLiteModel") {
val outputFile = downloadedModelPath.get().asFile
outputs.file(outputFile)
doLast {
if (!outputFile.exists()) {
println("Downloading $modelFileName from $modelUrl")
outputFile.parentFile.mkdirs()
URL(modelUrl).openStream().use { input ->
outputFile.outputStream().use { output ->
input.copyTo(output)
}
}
} else {
println("Model already downloaded: ${outputFile.absolutePath}")
}
}
}
val copyTFLiteToAssets = tasks.register<Copy>("copyTFLiteToAssets") {
dependsOn(downloadTFLiteModel)
from(downloadedModelPath)
into(generatedAssetsDir)
}
tasks.named("preBuild") {
dependsOn(copyTFLiteToAssets)
}

View File

@@ -53,7 +53,7 @@ class ImageSegmentationService(private val context: Context) {
fun initialize() { fun initialize() {
interpreter = try { interpreter = try {
val litertBuffer = FileUtil.loadMappedFile(context, "timm_efficientnet_lite0_quantized.tflite") val litertBuffer = FileUtil.loadMappedFile(context, "document-segmentation-model.tflite")
Log.i(TAG, "Loaded LiteRT model") Log.i(TAG, "Loaded LiteRT model")
val options = Interpreter.Options().apply { val options = Interpreter.Options().apply {
numThreads = 2 numThreads = 2