Android Studio

LiteRT : Android 온디바이스 AI 비교부터 MNIST 구현까지

혁준 2026. 1. 12. 14:56

 

1. 온디바이스 소개 

AI 활용이 만연해지면서 현재 많은 App들이 서비스에 AI를 도입하고 있습니다. 사진, 음성 같은 비정형 데이터를 분석해 주고 사용자 개인의 취향, 사용패턴, 상황에 맞춰 정교한 추천이나 알림이 가능해졌습니다. 이로 인해 사용자들을 App을 좀 더 편리하게 사용할 수 있게 되었습니다.

 

AI를 도입한다고 하면, 대부분 클라우드 AI를 생각할 것입니다. 클라우드 AI는 App(클라이언트)에서 생성된 데이터를 클라우드(서버 측 인프라)로 전송하고, 그곳에서 AI 모델이 연산을 수행한 뒤 결과를 앱에 반환합니다.

 

이렇게 클라우드 AI를 사용하면 대형 AI 모델 사용이 가능하기에 뛰어난 성능을 내기 용이하다는 명확한 장점이 존재합니다.

그렇지만 단점도 존재합니다. 클라우드 AI를 사용하면 비용적인 측면에서 부담이 될 수도 있으며, 민감한 개인정보에 대해 외부로 유출될 수도 있다는 위험이 존재합니다. 

 

이러한 클라우드 AI의 단점을 온디바이스 AI를 통해 해결할 수 있습니다. 

온디바이스 AI는 모델 추론이 사용자 기기 내부에서 수행됩니다. 입력 데이터가 기기를 벗어나지 않고, 디바이스의 CPU/GPU/NPU를 활용해 결과를 계산합니다. (대표적인 예시로는 오프라인 번역, 카메라 실시간 인식 등이 있습니다.)

 

이러한 온디바이스 AI를 사용함으로써 얻을 수 있는 이점이 있습니다. 

(1) 디바이스 내부에서만 동작하기 때문에 네트워크 연결이 필요가 없습니다. 그렇기에 네트워크 통신에 대한 비용이 발생하지 않으며, 통신 관련 이슈가 발생하지 않습니다. 이외에도 오프라인에서도 동작이 가능하기 때문에 다양한 환경에 제약을 받지 않고 사용이 가능해집니다. 

(2) 기기 내부에서 모델을 사용하기 때문에 개인 정보가 외부로 유출될 확률이 감소한다는 이점이 있습니다. 그렇지만 모델이 App에 포함되어 있는 경우 해당 모델에 대한 정보가 유출될 수 있기에 주의가 필요합니다. 

(3) 개인화된 기기를 기반으로 사용하기 때문에 추가적인 비용 또한 발생하지 않습니다

 

그렇지만 온디바이스 AI 또한 단점이 존재합니다. 

(1) App 내부에 모델이 포함되어 있는 경우에 앱 용량이 증가하게 되며, 모델을 업데이트하는 경우에 클라우드 AI는 서버에서 모델을 교체하면 즉시 반영이 가능하지만, 온디바이스 경우에는 App을 추가적으로 배포해줘야 합니다. 

(2) 현재 디바이스의 성능이 많이 좋아졌지만 그럼에도 불구하고 아직은 클라우드 AI와 비교했을 때 성능 및 품질적으로 부족합니다. 

 

그렇기에 이러한 요소들을 고려하여 현재 서비스 중인 자신의 App에 적합한 AI를 선택하는 것이 중요하다고 생각됩니다. 

 

 

 

2. ML Kit GenAI, Gemini Nano+AI Core, LiteRT, MediaPipe

Android에서 온디바이스 AI를 구현하는 길은 크게 두 갈래로 나뉩니다. 하나는 AICore 시스템 서비스를 통해 실행되는 Gemini Nano를 기반으로, ML Kit GenAI처럼 고수준 API로 요약·교정·재작성 등의 기능을 빠르게 사용하는 방식이고, 다른 하나는 Google AI Edge의 LiteRT과 MediaPipe를 활용해 커스텀 모델/실시간 파이프라인을 직접 구성하는 방식입니다.

 

(1) ML Kit GenAI API 

ML Kit GenAI는 ML Kit 안에 포함된 온디바이스 생성형 AI(GenAI)용 API입니다. 개발자는 요약, 교정, 재작성, 이미지 설명과 같은 AI 기능을 즉시 호출할 수 있게 해 줍니다. 내부적으로는 AI Core 위에서 실행되는 Gemini Nano를 활용해 온디바이스 GenAI를 제공합니다.

 

(2) Gemini nano 

온디바이스에서 동작하는 기초모델입니다. AI Core 위에서 실행됩니다. AICore는 휴대폰의 CPU/GPU/NPU 같은 하드웨어를 활용해 빠르게 AI가 동작하도록 돕고, Gemini Nano 모델과 필요한 파일들을 시스템 업데이트를 통해 자동으로 업데이트해 줍니다. 그래서 앱 개발자가 모델을 따로 내려받게 하거나 버전을 직접 관리하지 않아도 됩니다.

 

(3) Media Pipe

카메라/오디오 같은 스트리밍 입력을 받아 전처리–추론–후처리–렌더링까지 연결하는 실시간 파이프라인 구성에 강점이 있으며, 포즈/랜드마크 추정처럼 “프레임 단위 처리”가 핵심인 문제에 특히 적합합니다.

 

(4) LiteRT

커스텀 모델을 앱에 탑재하여 직접 추론하는 런타임으로, 입력 텐서 설계·전처리·후처리·가속기 선택을 포함해 구현 통제력이 가장 높습니다.

 

 

3. LiteRT 추가 설명

위에서 설명한 다양한 온디바이스 AI 중 LiteRT에 대해 좀 더 알아보고자 합니다. 

LiteRT는 Android 등 엣지(온디바이스) 환경에서 머신러닝 모델을 빠르고 효율적으로 추론하기 위한 런타임입니다. 이전에는 TensorFlow Lite(TFLite) 불렸었고, 현재는 LiteRT로 명칭이 변경되었습니다. 

 

Android에서 LiteRT를 사용함으로써, 커스텀 모델을 앱에서 직접 실행할 수 있습니다. 또한 CPU 뿐만 아니라 GPU, NPU를 적극 활용하여 성능을 끌어올릴 수 있다는 이점이 있습니다. 

CPU는 대부분의 기기에서 무난하게 잘 동작하고, 지원하는 연산도 폭넓어서 일단 확실히 돌아가게 하는 기본 경로로 쓰기 좋습니다. GPU는 여러 계산을 동시에 처리하는 데 강해서, 컨볼루션이나 행렬곱처럼 병렬 처리가 잘 되는 연산에서는 속도가 더 빨라지는 경우가 많아 이미지,텐서 추론에 자주 활용됩니다. NPU는 기기마다 탑재 여부와 성능이 다르며, 실제로 얼마나 빨라질지는 해당 기기의 드라이버/가속기 지원 상태와 모델이 사용하는 연산 및 정밀도가 호환되는지에 따라 크게 달라집니다. 

 

 

이제 LiteRT를 사용해서 mnist 예제를 구현해 보도록 하겠습니다.

 

 

4. LiteRT Android에서 구현해보기 (mnist)

LiteRT는 모델을 직접 커스텀해서 사용할 수 있다는 장점이 있습니다. 그렇기에 mnist 모델이 필요합니다. 

https://ai.google.dev/edge/litert/conversion/tensorflow/quantization/post_training_quant?utm_source=chatgpt.com

 

Post-training dynamic range quantization  |  Google AI Edge  |  Google AI for Developers

Send feedback Post-training dynamic range quantization Copyright 2024 The AI Edge Authors. Licensed under the Apache License, Version 2.0 (the "License"); Toggle code # you may not use this file except in compliance with the License. # You may obtain a cop

ai.google.dev

저는 해당 페이지의 코드를 사용해서 모델을 생성하였습니다.

 

# Load MNIST dataset
mnist = keras.datasets.mnist
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()

# Normalize the input image so that each pixel value is between 0 to 1.
train_images = train_images / 255.0
test_images = test_images / 255.0

# Define the model architecture
model = keras.Sequential([
  keras.layers.InputLayer(input_shape=(28, 28)),
  keras.layers.Reshape(target_shape=(28, 28, 1)),
  keras.layers.Conv2D(filters=12, kernel_size=(3, 3), activation=tf.nn.relu),
  keras.layers.MaxPooling2D(pool_size=(2, 2)),
  keras.layers.Flatten(),
  keras.layers.Dense(10)
])

# Train the digit classification model
model.compile(optimizer='adam',
              loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])
model.fit(
  train_images,
  train_labels,
  epochs=1,
  validation_data=(test_images, test_labels)
)

keras에서 제공하는 MNIST 데이터셋 로더를 사용합니다. 

MNIST는 손글씨 숫자(0~9)를 분류해 주는 기능을 수행합니다.

 

상단의 코드에서는 MNIST라는 손글씨 숫자 이미지 데이터를 불러온 다음, 이미지 픽셀 값을 0~255에서 0~1 범위로 바꿔 학습하기 쉬운 형태로 전처리합니다. 그 후에 Conv2D와 MaxPooling으로 숫자 모양의 특징을 뽑아내고, 이를 펼쳐서(Flatten) 마지막에 10개 출력(Dense(10))으로 0부터 9까지 각 숫자에 대한 점수를 계산하는 간단한 CNN 모델을 만듭니다.

마지막으로 Adam 옵티마이저와 분류용 손실 함수를 설정해 모델을 학습시키며, 학습 데이터 전체를 한 번(epochs=1) 돌려 훈련하고, 동시에 테스트 데이터로 성능도 함께 확인합니다.

 

converter = tf.lite.TFLiteConverter.from_keras_model(model)
tflite_model = converter.convert()

 

해당 코드를 사용하여 Keras 모델 객체를 TFLite 바이트로 생성합니다. 

 

tflite_models_dir = pathlib.Path("/tmp/mnist_tflite_models/")
tflite_models_dir.mkdir(exist_ok=True, parents=True)
tflite_model_file = tflite_models_dir / "mnist_model.tflite"
tflite_model_file.write_bytes(tflite_model)

 

그 후에 설정한 경로에 파일로 저장을 합니다. 

 

이후에 상단의 이미지와 같이 Asset Folder를 추가한 후, 생성해 둔 .tflite 파일을 추가합니다. 

그 후에 Android Studio에서 코드를 작성하면 됩니다.

 

우선 LiteRT를 사용하기 위해 dependencies에 하단의 링크를 참고해서 라이브러리를 추가해 줍니다.
https://mvnrepository.com/artifact/com.google.ai.edge.litert/litert

https://mvnrepository.com/artifact/io.coil-kt.coil3/coil

 

 

 

/**
 * LiteRT 기반 MNIST 숫자 분류기
 */
class MnistDigitClassifier(
    context: Context,
    accelerator: Accelerator = Accelerator.CPU,
    modelAssetName: String = "mnist_model.tflite",
) : AutoCloseable {

    private val model: CompiledModel =
        CompiledModel.create(
            context.assets,
            modelAssetName,
            CompiledModel.Options(accelerator),
            null,
        )

    private val inputBuffers = model.createInputBuffers()
    private val outputBuffers = model.createOutputBuffers()

    /**
     * @return 확률 내림차순 Top-K (digit, probability)
     */
    @RequiresApi(Build.VERSION_CODES.O)
    fun classify(bitmap: Bitmap, topK: Int): List<Pair<Int, Float>> {
        val input = preprocessToMnistFloat(bitmap)

        // 입력
        inputBuffers[0].writeFloat(input)

        // 추론
        model.run(inputBuffers, outputBuffers)

        // 출력 읽기
        val raw = outputBuffers[0].readFloat()
        val probs = softmax(raw)

        return probs
            .mapIndexed { idx, p -> idx to p }
            .sortedByDescending { it.second }
            .take(topK)
    }

    /**
     * HARDWARE 비트맵 방어: getPixels() 호출 전에 항상 SOFTWARE ARGB_8888로 변환
     */
    @RequiresApi(Build.VERSION_CODES.O)
    private fun Bitmap.ensureSoftwareARGB8888(): Bitmap {
        if (this.config == Bitmap.Config.ARGB_8888 && this.isMutable) return this
        if (this.config != Bitmap.Config.HARDWARE && this.config == Bitmap.Config.ARGB_8888) return this

        val safe = createBitmap(width, height)
        Canvas(safe).drawBitmap(this, 0f, 0f, null)
        return safe
    }

    @RequiresApi(Build.VERSION_CODES.O)
    private fun preprocessToMnistFloat(src: Bitmap): FloatArray {
        val safeSrc = src.ensureSoftwareARGB8888()
        val resized = safeSrc.scale(28, 28)

        val pixels = IntArray(28 * 28)
        resized.getPixels(pixels, 0, 28, 0, 0, 28, 28)

        val out = FloatArray(28 * 28)
        for (i in pixels.indices) {
            val argb = pixels[i]
            val red = (argb shr 16) and 0xFF
            val green = (argb shr 8) and 0xFF
            val blue = (argb) and 0xFF

            val gray = (0.299f * red + 0.587f * green + 0.114f * blue)
            out[i] = gray / 255f
        }
        return out
    }

    private fun softmax(logits: FloatArray): FloatArray {
        if (logits.isEmpty()) return logits
        val max = logits.maxOrNull() ?: 0f
        var sum = 0.0
        val exps = DoubleArray(logits.size)
        for (i in logits.indices) {
            val e = exp((logits[i] - max).toDouble())
            exps[i] = e
            sum += e
        }
        return FloatArray(logits.size) { i -> (exps[i] / sum).toFloat() }
    }

    override fun close() {
        inputBuffers.forEach { it.close() }
        outputBuffers.forEach { it.close() }
        model.close()
    }
}

/**
 * Coil로 SOFTWARE Bitmap 확보 (HARDWARE 비트맵 방지)
 */
suspend fun loadSoftwareBitmapWithCoil(
    context: Context,
    uri: Uri,
): Bitmap = withContext(Dispatchers.IO) {
    val imageLoader = SingletonImageLoader.get(context)

    val request = ImageRequest.Builder(context)
        .data(uri)
        .allowHardware(false) // HARDWARE 금지
        .build()

    val result = imageLoader.execute(request)
    if (result is SuccessResult) {
        return@withContext result.image.toBitmap()
    }
    error("Image load failed: $result")
}

 

 

class MainActivity : ComponentActivity() {
    @RequiresApi(Build.VERSION_CODES.O)
    override fun onCreate(savedInstanceState: Bundle?) {
        super.onCreate(savedInstanceState)
        setContent {
            MaterialTheme {
                Surface(Modifier.fillMaxSize()) {
                    MnistScreen()
                }
            }
        }
    }
}

@RequiresApi(Build.VERSION_CODES.O)
@Composable
private fun MnistScreen() {
    val context = LocalContext.current

    var selectedUri by remember { mutableStateOf<Uri?>(null) }

    var results by remember { mutableStateOf<List<Pair<Int, Float>>>(emptyList()) }
    var error by remember { mutableStateOf<String?>(null) }
    var loading by remember { mutableStateOf(false) }

    val classifier = remember {
        MnistDigitClassifier(
            context = context.applicationContext,
            accelerator = Accelerator.CPU
        )
    }

    DisposableEffect(Unit) {
        onDispose { classifier.close() }
    }

    val picker = rememberLauncherForActivityResult(
        contract = ActivityResultContracts.GetContent()
    ) { uri: Uri? ->
        // 선택만 하고, 실제 로드 및 추론은 LaunchedEffect에서 처리
        selectedUri = uri
    }

    /**
     * URI가 바뀌면 Coil로 로드 + 분류
     */
    LaunchedEffect(selectedUri) {
        val uri = selectedUri ?: return@LaunchedEffect

        loading = true
        error = null
        results = emptyList()

        runCatching {
            val bmp = loadSoftwareBitmapWithCoil(context, uri)
            bmp
        }.onSuccess { bmp ->
            results = classifier.classify(bmp, topK = 5)
        }.onFailure { t ->
            error = t.message ?: "이미지 처리/추론 중 오류가 발생했습니다."
        }.also {
            loading = false
        }
    }

    Column(
        Modifier
            .fillMaxSize()
            .padding(16.dp)
            .verticalScroll(rememberScrollState())
    ) {
        Text("LiteRT MNIST", style = MaterialTheme.typography.headlineSmall)
        Spacer(Modifier.height(12.dp))

        Button(onClick = { picker.launch("image/*") }) {
            Text("이미지 선택")
        }

        Spacer(Modifier.height(16.dp))

        if (loading) {
            LinearProgressIndicator(Modifier.fillMaxWidth())
            Spacer(Modifier.height(12.dp))
        }

        selectedUri?.let { uri ->
            Text("선택한 이미지", style = MaterialTheme.typography.titleMedium)
            Spacer(Modifier.height(8.dp))
            AsyncImage(
                model = uri,
                contentDescription = "selected image",
                modifier = Modifier
                    .fillMaxWidth()
                    .height(220.dp)
            )
        }

        Spacer(Modifier.height(16.dp))

        error?.let {
            Text(it, color = MaterialTheme.colorScheme.error)
            Spacer(Modifier.height(12.dp))
        }

        if (results.isNotEmpty()) {
            Text("결과(Top-K)", style = MaterialTheme.typography.titleMedium)
            Spacer(Modifier.height(8.dp))
            results.forEach { (digit, prob) ->
                Text("digit=$digit   p=${"%.4f".format(prob)}")
            }
        } else {
            Text(
                "MNIST 스타일(검은 배경에 흰 숫자) 이미지일수록 결과가 잘 나옵니다.\n" +
                        "일반 사진은 MNIST 분포와 달라 정확도가 낮을 수 있습니다."
            )
        }
    }
}

 

간단하게 Jetpack Compose를 사용하여 화면을 구현해 보았습니다.

이미지 선택 버튼 클릭 후 갤러리에서 적합한 mnist 이미지를 선택하여 업로드하면, 화면에 업로드된 이미지와 결과를 볼 수 있습니다 😊