Make gradle download the TFLite segmentation model (#8)
This commit is contained in:
@@ -7,6 +7,7 @@ plugins {
|
|||||||
android {
|
android {
|
||||||
namespace = "org.mydomain.myscan"
|
namespace = "org.mydomain.myscan"
|
||||||
compileSdk = 35
|
compileSdk = 35
|
||||||
|
sourceSets["main"].assets.srcDir(layout.buildDirectory.dir("generated/assets"))
|
||||||
|
|
||||||
defaultConfig {
|
defaultConfig {
|
||||||
applicationId = "org.mydomain.myscan"
|
applicationId = "org.mydomain.myscan"
|
||||||
@@ -43,6 +44,8 @@ android {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
apply(from = "download-tflite.gradle.kts")
|
||||||
|
|
||||||
dependencies {
|
dependencies {
|
||||||
|
|
||||||
implementation(libs.androidx.core.ktx)
|
implementation(libs.androidx.core.ktx)
|
||||||
|
|||||||
38
app/download-tflite.gradle.kts
Normal file
38
app/download-tflite.gradle.kts
Normal file
@@ -0,0 +1,38 @@
|
|||||||
|
import java.net.URL
|
||||||
|
import org.gradle.api.tasks.Copy
|
||||||
|
|
||||||
|
val modelVersion = "v0.2"
|
||||||
|
val modelFileName = "document-segmentation-model.tflite"
|
||||||
|
val modelUrl = "https://github.com/pynicolas/document-segmentation-model/releases/download/$modelVersion/$modelFileName"
|
||||||
|
|
||||||
|
val downloadedModelPath = layout.buildDirectory.file("downloads/$modelFileName")
|
||||||
|
val generatedAssetsDir = layout.buildDirectory.dir("generated/assets")
|
||||||
|
|
||||||
|
val downloadTFLiteModel = tasks.register("downloadTFLiteModel") {
|
||||||
|
val outputFile = downloadedModelPath.get().asFile
|
||||||
|
outputs.file(outputFile)
|
||||||
|
|
||||||
|
doLast {
|
||||||
|
if (!outputFile.exists()) {
|
||||||
|
println("Downloading $modelFileName from $modelUrl")
|
||||||
|
outputFile.parentFile.mkdirs()
|
||||||
|
URL(modelUrl).openStream().use { input ->
|
||||||
|
outputFile.outputStream().use { output ->
|
||||||
|
input.copyTo(output)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
println("Model already downloaded: ${outputFile.absolutePath}")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
val copyTFLiteToAssets = tasks.register<Copy>("copyTFLiteToAssets") {
|
||||||
|
dependsOn(downloadTFLiteModel)
|
||||||
|
from(downloadedModelPath)
|
||||||
|
into(generatedAssetsDir)
|
||||||
|
}
|
||||||
|
|
||||||
|
tasks.named("preBuild") {
|
||||||
|
dependsOn(copyTFLiteToAssets)
|
||||||
|
}
|
||||||
@@ -53,7 +53,7 @@ class ImageSegmentationService(private val context: Context) {
|
|||||||
|
|
||||||
fun initialize() {
|
fun initialize() {
|
||||||
interpreter = try {
|
interpreter = try {
|
||||||
val litertBuffer = FileUtil.loadMappedFile(context, "timm_efficientnet_lite0_quantized.tflite")
|
val litertBuffer = FileUtil.loadMappedFile(context, "document-segmentation-model.tflite")
|
||||||
Log.i(TAG, "Loaded LiteRT model")
|
Log.i(TAG, "Loaded LiteRT model")
|
||||||
val options = Interpreter.Options().apply {
|
val options = Interpreter.Options().apply {
|
||||||
numThreads = 2
|
numThreads = 2
|
||||||
|
|||||||
Reference in New Issue
Block a user