From 27c3215f0296148426e42cb45c1c954af50fb42c Mon Sep 17 00:00:00 2001 From: Ryan Kuester Date: Sun, 24 May 2026 23:14:16 -0500 Subject: [PATCH 01/19] chore(compression): remove model_facade.py Remove model_facade module and its tests, now superseded by model_editor. BUG=part of #3256 --- tensorflow/lite/micro/compression/BUILD | 20 -- .../lite/micro/compression/model_facade.py | 276 ------------------ .../micro/compression/model_facade_test.py | 144 --------- 3 files changed, 440 deletions(-) delete mode 100644 tensorflow/lite/micro/compression/model_facade.py delete mode 100644 tensorflow/lite/micro/compression/model_facade_test.py diff --git a/tensorflow/lite/micro/compression/BUILD b/tensorflow/lite/micro/compression/BUILD index 36725fac63c..9cbd5899047 100644 --- a/tensorflow/lite/micro/compression/BUILD +++ b/tensorflow/lite/micro/compression/BUILD @@ -169,26 +169,6 @@ py_test( ], ) -tflm_py_library( - name = "model_facade", - srcs = ["model_facade.py"], - deps = [ - "//tensorflow/lite/python:schema_py", - requirement("flatbuffers"), - ], -) - -py_test( - name = "model_facade_test", - size = "small", - srcs = ["model_facade_test.py"], - target_compatible_with = INCOMPATIBLE_WITH_WINDOWS, - deps = [ - ":model_facade", - ":test_models", - ], -) - tflm_py_library( name = "spec", srcs = ["spec.py"], diff --git a/tensorflow/lite/micro/compression/model_facade.py b/tensorflow/lite/micro/compression/model_facade.py deleted file mode 100644 index 2e58d8080f1..00000000000 --- a/tensorflow/lite/micro/compression/model_facade.py +++ /dev/null @@ -1,276 +0,0 @@ -# Copyright 2024 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -"""A facade for working with tflite.Model. - -This module provides convenient navigation, data type conversions, and -utilities for working with a tflite.Model, which can be tedious and verbose to -work with directly. - -Usage: - model = model_facade.read(flatbuffer) - # manipulate - new_flatbuffer = model.compile() -""" - -from __future__ import annotations - -import flatbuffers -import numpy as np -from numpy.typing import NDArray -from tflite_micro.tensorflow.lite.python import schema_py_generated as tflite -from typing import ByteString, Generic, TypeVar - -_IteratorTo = TypeVar("_IteratorTo") - - -class _Iterator(Generic[_IteratorTo]): - - def __init__(self, sequence, cls, parent): - self._sequence = sequence - self._cls = cls - self._index = 0 - self._parent = parent - - def __getitem__(self, key) -> _IteratorTo: - return self._cls(self._sequence[key], key, self._parent) - - def __len__(self): - return len(self._sequence) - - def __iter__(self): - self._index = 0 - return self - - def __next__(self): - try: - result = self[self._index] - self._index += 1 - return result - except IndexError: - raise StopIteration - - -class _IndirectIterator(Generic[_IteratorTo]): - - def __init__(self, indices, sequence): - self._indices = indices - self._index = 0 - self._sequence = sequence - - def __getitem__(self, key) -> _IteratorTo: - index = self._indices[key] - return self._sequence[index] - - def __len__(self): - return len(self._indices) - - def __iter__(self): - self._index = 0 - return self - - def __next__(self): - try: - result = self[self._index] - self._index += 1 - return result - except IndexError: - raise StopIteration - - -class _Operator: - - def __init__(self, operator, index, subgraph): - self.operator = operator - self.index = index - self.subgraph = subgraph - - @property - def opcode(self) -> tflite.OperatorCodeT: - return self.subgraph.model.operatorCodes[self.operator.opcodeIndex] - - @property - def inputs(self): - return _IndirectIterator(self.operator.inputs, self.subgraph.tensors) - - -_NP_DTYPES = { - tflite.TensorType.FLOAT16: np.dtype(" _Buffer: - return self.subgraph.model.buffers[self._tensor_t.buffer] - - @property - def data(self) -> bytes: - return self.buffer.data - - @property - def dtype(self) -> np.dtype: - return _NP_DTYPES[self._tensor_t.type] - - @property - def array(self) -> np.ndarray: - """Returns an array created from the Tensor's data, type, and shape. - - Note the bytes in the data buffer and the Tensor's type and shape may be - inconsistent, and thus the returned array invalid, if the data buffer has - been altered according to the compression schema, in which the data buffer - is an array of fixed-width, integer fields. - """ - return np.frombuffer(self.data, - dtype=self.dtype).reshape(self._tensor_t.shape) - - @property - def quantization(self) -> tflite.QuantizationParametersT | None: - return self._tensor_t.quantization - - -class _Buffer: - - def __init__(self, buffer_t: tflite.BufferT, index, model): - self._buffer_t = buffer_t - self.index = index - self.model = model - - @property - def data(self) -> bytes: - return bytes(self._buffer_t.data) - - @data.setter - def data(self, value: ByteString): - self._buffer_t.data = list(value) - - def extend(self, values: NDArray): - self._buffer_t.data.extend(values.tobytes()) - - -class _Subgraph: - - def __init__(self, subgraph_t: tflite.SubGraphT, index: int, model: _Model): - self._subgraph_t = subgraph_t - self.index = index - self.model = model - - @property - def operators(self) -> _Iterator[_Operator]: - return _Iterator(self._subgraph_t.operators, _Operator, parent=self) - - @property - def tensors(self) -> _Iterator[_Tensor]: - return _Iterator(self._subgraph_t.tensors, _Tensor, parent=self) - - -class _Model: - """A facade for manipulating tflite.Model. - """ - - def __init__(self, model_t: tflite.ModelT): - self._model_t = model_t - - def compile(self) -> bytearray: - """Returns a tflite.Model flatbuffer. - """ - size_hint = 4 * 2**10 - builder = flatbuffers.Builder(size_hint) - builder.Finish(self._model_t.Pack(builder)) - return builder.Output() - - def add_buffer(self) -> _Buffer: - """Adds a buffer to the model. - """ - buffer = tflite.BufferT() - buffer.data = [] - self._model_t.buffers.append(buffer) - index = len(self._model_t.buffers) - 1 - return _Buffer(buffer, index, self._model_t) - - def add_metadata(self, key, value): - """Adds a key-value pair, writing value to a newly created buffer. - """ - metadata = tflite.MetadataT() - metadata.name = key - buffer = self.add_buffer() - buffer.data = value - metadata.buffer = buffer.index - self._model_t.metadata.append(metadata) - - @property - def metadata(self) -> dict[str, _Buffer]: - """Returns the model's metadata as a dictionary to Buffer objects. - """ - result = {} - for m in self._model_t.metadata: - name = m.name.decode("utf-8") # type: ignore (fb library is wrong) - buffer = _Buffer(self._model_t.buffers[m.buffer], m.buffer, - self._model_t) - result[name] = buffer - - return result - - @property - def operatorCodes(self): - return self._model_t.operatorCodes - - @property - def subgraphs(self) -> _Iterator[_Subgraph]: - return _Iterator(self._model_t.subgraphs, _Subgraph, parent=self) - - @property - def buffers(self) -> _Iterator[_Buffer]: - return _Iterator(self._model_t.buffers, _Buffer, parent=self) - - -def read(buffer: ByteString): - """Reads a tflite.Model and returns a model facade. - """ - schema_model = tflite.ModelT.InitFromPackedBuf(buffer, 0) - return _Model(schema_model) diff --git a/tensorflow/lite/micro/compression/model_facade_test.py b/tensorflow/lite/micro/compression/model_facade_test.py deleted file mode 100644 index 87e71fa968b..00000000000 --- a/tensorflow/lite/micro/compression/model_facade_test.py +++ /dev/null @@ -1,144 +0,0 @@ -# Copyright 2024 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import numpy as np -import unittest -from tflite_micro.tensorflow.lite.python import schema_py_generated as tflite -from tflite_micro.tensorflow.lite.micro.compression import model_facade -from tflite_micro.tensorflow.lite.micro.compression import test_models - -TEST_MODEL = { - "operator_codes": { - 0: { - "builtin_code": tflite.BuiltinOperator.FULLY_CONNECTED, - }, - 1: { - "builtin_code": tflite.BuiltinOperator.ADD, - }, - }, - "metadata": { - 0: { - "name": "metadata0", - "buffer": 0 - }, - 1: { - "name": "metadata1", - "buffer": 0 - }, - }, - "subgraphs": { - 0: { - "operators": { - 0: { - "opcode_index": 1, # ADD - "inputs": ( - 1, - 2, - ), - "outputs": (3, ), - }, - 1: { - "opcode_index": 0, # FULLY_CONNECTED - "inputs": ( - 3, - 4, - 5, - ), - "outputs": (6, ), - }, - }, - "tensors": { - 0: { - "name": "tensor0", - "shape": (16, 1), - "type": tflite.TensorType.INT8, - "buffer": 1, - }, - 1: { - "name": "tensor1", - "shape": (8, 1), - "type": tflite.TensorType.INT16, - "buffer": 2, - }, - 2: { - "name": "tensor2", - "shape": (4, 1), - "type": tflite.TensorType.INT32, - "buffer": 3, - }, - 3: { - "name": "tensor3", - "shape": (2, 1), - "type": tflite.TensorType.INT64, - "buffer": 4, - }, - }, - }, - }, - "buffers": { - 0: None, - 1: np.array(range(16), dtype=np.dtype(" Date: Sun, 24 May 2026 23:16:20 -0500 Subject: [PATCH 02/19] refactor(compression): replace test_models with model_editor in compress_test Replace dictionary-based test_models.build() with model_editor's declarative API for building test models. BUG=part of #3256 --- tensorflow/lite/micro/compression/BUILD | 1 - .../lite/micro/compression/compress_test.py | 218 ++++++------------ 2 files changed, 66 insertions(+), 153 deletions(-) diff --git a/tensorflow/lite/micro/compression/BUILD b/tensorflow/lite/micro/compression/BUILD index 9cbd5899047..b5042f08130 100644 --- a/tensorflow/lite/micro/compression/BUILD +++ b/tensorflow/lite/micro/compression/BUILD @@ -162,7 +162,6 @@ py_test( ":metadata_py", ":model_editor", ":spec", - ":test_models", "//tensorflow/lite/python:schema_py", requirement("bitarray"), requirement("numpy"), diff --git a/tensorflow/lite/micro/compression/compress_test.py b/tensorflow/lite/micro/compression/compress_test.py index ee10a75f36d..81bbdab3293 100644 --- a/tensorflow/lite/micro/compression/compress_test.py +++ b/tensorflow/lite/micro/compression/compress_test.py @@ -21,7 +21,6 @@ from tflite_micro.tensorflow.lite.micro.compression import metadata_py_generated as schema from tflite_micro.tensorflow.lite.micro.compression import model_editor from tflite_micro.tensorflow.lite.micro.compression import spec -from tflite_micro.tensorflow.lite.micro.compression import test_models from tflite_micro.tensorflow.lite.python import schema_py_generated as tflite @@ -170,153 +169,70 @@ def test_multiple_tables_with_padding(self): self.assertEqual(result, expected_output) -# yapf: disable -TEST_MODEL = { - "operator_codes": { - 0: { - "builtin_code": tflite.BuiltinOperator.ADD, - }, - }, - "metadata": { - 0: { - "name": "metadata0", - "buffer": 0 - }, - }, - "subgraphs": { - 0: { - "operators": { - 0: { - "opcode_index": 0, - "inputs": ( - 0, - 1, - ), - "outputs": (2, ), - }, - }, - "tensors": { - 0: { - "shape": (16, 1), - "type": tflite.TensorType.UINT8, - "buffer": 1, - "quantization": { - "quantized_dimension": 1, - "scale": (1,), - "zero_point": (0,), - }, - }, - 1: { - "shape": (16, 1), - "type": tflite.TensorType.INT8, - "buffer": 2, - "quantization": { - "quantized_dimension": 1, - "scale": (1,), - "zero_point": (0,), - }, - }, - 2: { - "shape": (16, 1), - "type": tflite.TensorType.INT16, - "buffer": 3, - "quantization": { - "quantized_dimension": 1, - "scale": (1,), - "zero_point": (0,), - }, - }, - 3: { - "shape": (16, 1), - "type": tflite.TensorType.INT32, - "buffer": 4, - "quantization": { - "quantized_dimension": 1, - "scale": (1,), - "zero_point": (0,), - }, - }, - 4: { - "shape": (16, 1), - "type": tflite.TensorType.INT32, - "buffer": 5, - "quantization": { - "quantized_dimension": 1, - "scale": (1,), - "zero_point": (0,), - }, - }, - 5: { - "shape": (4, 5), - "type": tflite.TensorType.INT16, - "buffer": 6, - "quantization": { - "quantized_dimension": 1, - "scale": (1, 1, 1, 1, 1), - "zero_point": (0, 0, 0, 0, 0), - }, - }, - 6: { - "shape": (5, 4), - "type": tflite.TensorType.INT16, - "buffer": 7, - "quantization": { - "quantized_dimension": 0, - "scale": (1, 1, 1, 1, 1), - "zero_point": (0, 0, 0, 0, 0), - }, - }, - 7: { - "shape": (5, 4), - "type": tflite.TensorType.INT16, - "buffer": 8, - "quantization": { - "quantized_dimension": 0, - "scale": (1,), - "zero_point": (0,), - }, - }, - 8: { - "shape": (16, 1), - "type": tflite.TensorType.UINT8, - "buffer": 9, - }, - }, - }, - }, - "buffers": { - 0: None, - - 1: np.array(range(16), dtype=np.dtype(" Date: Sun, 24 May 2026 23:17:18 -0500 Subject: [PATCH 03/19] chore(compression): remove test_models.py Remove test_models module and its tests, now superseded by model_editor. BUG=part of #3256 --- tensorflow/lite/micro/compression/BUILD | 21 -- .../lite/micro/compression/test_models.py | 190 ------------------ .../micro/compression/test_models_test.py | 32 --- 3 files changed, 243 deletions(-) delete mode 100644 tensorflow/lite/micro/compression/test_models.py delete mode 100644 tensorflow/lite/micro/compression/test_models_test.py diff --git a/tensorflow/lite/micro/compression/BUILD b/tensorflow/lite/micro/compression/BUILD index b5042f08130..32f939a49dc 100644 --- a/tensorflow/lite/micro/compression/BUILD +++ b/tensorflow/lite/micro/compression/BUILD @@ -205,27 +205,6 @@ py_test( ], ) -tflm_py_library( - name = "test_models", - srcs = ["test_models.py"], - deps = [ - "//tensorflow/lite/python:schema_py", - requirement("flatbuffers"), - requirement("numpy"), - ], -) - -py_test( - name = "test_models_test", - size = "small", - srcs = ["test_models_test.py"], - target_compatible_with = INCOMPATIBLE_WITH_WINDOWS, - deps = [ - ":test_models", - "//tensorflow/lite/python:schema_py", - ], -) - tflm_py_library( name = "tensor_type", srcs = ["tensor_type.py"], diff --git a/tensorflow/lite/micro/compression/test_models.py b/tensorflow/lite/micro/compression/test_models.py deleted file mode 100644 index 80286d17359..00000000000 --- a/tensorflow/lite/micro/compression/test_models.py +++ /dev/null @@ -1,190 +0,0 @@ -# Copyright 2024 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -"""Tools for constructing flatbuffers for testing. - -This module provides tools for constructing .tflite flatbuffers from a Python -dictionary representation of a model, a prototype of which can be found in -EXAMPLE_MODEL. - -Example usage: - model_definition = {...} # use EXAMPLE_MODEL as prototype - flatbuffer: bytearray = test_models.build(model_definition) -""" - -# This module must remain low-level and independent from any helpers in this -# project which make constructing model and flatbuffers easier, because this -# module is used to define tests for those helpers. - -import flatbuffers -import numpy as np -from tflite_micro.tensorflow.lite.python import schema_py_generated as tflite - -EXAMPLE_MODEL = { - "operator_codes": { - 0: { - "builtin_code": tflite.BuiltinOperator.FULLY_CONNECTED, - }, - 1: { - "builtin_code": tflite.BuiltinOperator.ADD, - }, - }, - "metadata": { - 0: { - "name": "metadata0", - "buffer": 0 - }, - }, - "subgraphs": { - 0: { - "operators": { - 0: { - "opcode_index": 1, - "inputs": ( - 0, - 1, - ), - "outputs": (3, ), - }, - 1: { - "opcode_index": 0, - "inputs": ( - 3, - 2, - ), - "outputs": (4, ), - }, - }, - "tensors": { - 0: { - "shape": (16, 1), - "type": tflite.TensorType.INT8, - "buffer": 1, - }, - 1: { - "shape": (16, 1), - "type": tflite.TensorType.INT8, - "buffer": 1, - }, - 2: { - "shape": (16, 1), - "type": tflite.TensorType.INT8, - "buffer": 1, - }, - 3: { - "shape": (16, 1), - "type": tflite.TensorType.INT8, - "buffer": 1, - "quantization": { - "quantized_dimension": 0, - }, - }, - }, - }, - }, - "buffers": { - 0: None, - 1: np.array(range(16), dtype=np.dtype(" bytearray: - """Builds a .tflite flatbuffer from a model definition. - - Args: - model_definition: A dictionary representation of the model, a prototype of - which can be found in the EXAMPLE_MODEL attribute of this module. - - Returns: - A tflite flatbuffer. - """ - root = tflite.ModelT() - description = model_definition.get("description") - if description is not None: - root.description = description - - root.operatorCodes = [] - for id, operator_code in model_definition["operator_codes"].items(): - assert id == len(root.operatorCodes) - opcode_t = tflite.OperatorCodeT() - root.operatorCodes.append(opcode_t) - opcode_t.builtinCode = operator_code["builtin_code"] - - root.metadata = [] - if "metadata" in model_definition: - for _, metadata in model_definition["metadata"].items(): - metadata_t = tflite.MetadataT() - metadata_t.name = metadata["name"] - metadata_t.buffer = metadata["buffer"] - root.metadata.append(metadata_t) - - root.subgraphs = [] - for id, subgraph in model_definition["subgraphs"].items(): - assert id == len(root.subgraphs) - subgraph_t = tflite.SubGraphT() - root.subgraphs.append(subgraph_t) - - subgraph_t.operators = [] - for id, operator in subgraph["operators"].items(): - assert id == len(subgraph_t.operators) - operator_t = tflite.OperatorT() - operator_t.opcodeIndex = operator["opcode_index"] - operator_t.inputs = operator["inputs"] - operator_t.outputs = operator["outputs"] - subgraph_t.operators.append(operator_t) - - subgraph_t.tensors = [] - for id, tensor in subgraph["tensors"].items(): - assert id == len(subgraph_t.tensors) - tensor_t = tflite.TensorT() - tensor_t.name = tensor.get("name", None) - tensor_t.shape = tensor["shape"] - tensor_t.type = tensor["type"] - tensor_t.buffer = tensor["buffer"] - - if "quantization" in tensor: - tensor_t.quantization = tflite.QuantizationParametersT() - tensor_t.quantization.quantizedDimension = \ - tensor["quantization"].get("quantized_dimension", None) - tensor_t.quantization.scale = \ - tensor["quantization"].get("scale", None) - tensor_t.quantization.zeroPoint = \ - tensor["quantization"].get("zero_point", None) - - subgraph_t.tensors.append(tensor_t) - - root.buffers = [] - for id, data in model_definition["buffers"].items(): - assert id == len(root.buffers) - buffer_t = tflite.BufferT() - - if data is None: - buffer_t.data = [] - elif isinstance(data, np.ndarray): - array = data.astype(data.dtype.newbyteorder("<")) # ensure little-endian - buffer_t.data = list(array.tobytes()) - else: - raise TypeError(f"buffer_id {id} must be None or an np.ndarray") - - root.buffers.append(buffer_t) - - size_hint = 1 * 2**20 - builder = flatbuffers.Builder(size_hint) - builder.Finish(root.Pack(builder)) - flatbuffer = builder.Output() - return flatbuffer diff --git a/tensorflow/lite/micro/compression/test_models_test.py b/tensorflow/lite/micro/compression/test_models_test.py deleted file mode 100644 index d7e961c2dd9..00000000000 --- a/tensorflow/lite/micro/compression/test_models_test.py +++ /dev/null @@ -1,32 +0,0 @@ -# Copyright 2024 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import unittest -from tflite_micro.tensorflow.lite.micro.compression import test_models -from tflite_micro.tensorflow.lite.python import schema_py_generated as tflite - - -class TestBuild(unittest.TestCase): - - def setUp(self): - self.flatbuffer = test_models.build(test_models.EXAMPLE_MODEL) - - def testNotDegenerate(self): - model = tflite.ModelT.InitFromPackedBuf(self.flatbuffer, 0) - self.assertEqual(model.operatorCodes[0].builtinCode, - tflite.BuiltinOperator.FULLY_CONNECTED) - - -if __name__ == "__main__": - unittest.main() From bfac49aea553d60d79aab65c40d77e6ed5100f86 Mon Sep 17 00:00:00 2001 From: Ryan Kuester Date: Sun, 24 May 2026 23:19:41 -0500 Subject: [PATCH 04/19] feat(compression): add DECODE operator types and metadata Add decode module with DecodeType constants and DecodeCommonMetadata, per the TFLM DECODE Operator Design document. BUG=part of #3256 --- tensorflow/lite/micro/compression/BUILD | 14 + tensorflow/lite/micro/compression/decode.py | 240 ++++++++++++++++++ .../lite/micro/compression/decode_test.py | 155 +++++++++++ 3 files changed, 409 insertions(+) create mode 100644 tensorflow/lite/micro/compression/decode.py create mode 100644 tensorflow/lite/micro/compression/decode_test.py diff --git a/tensorflow/lite/micro/compression/BUILD b/tensorflow/lite/micro/compression/BUILD index 32f939a49dc..8c22f9d9c71 100644 --- a/tensorflow/lite/micro/compression/BUILD +++ b/tensorflow/lite/micro/compression/BUILD @@ -247,6 +247,20 @@ tflm_py_test( ], ) +tflm_py_library( + name = "decode", + srcs = ["decode.py"], +) + +tflm_py_test( + name = "decode_test", + size = "small", + srcs = ["decode_test.py"], + deps = [ + ":decode", + ], +) + tflm_py_binary( name = "view", srcs = [ diff --git a/tensorflow/lite/micro/compression/decode.py b/tensorflow/lite/micro/compression/decode.py new file mode 100644 index 00000000000..df8428310a3 --- /dev/null +++ b/tensorflow/lite/micro/compression/decode.py @@ -0,0 +1,240 @@ +# Copyright 2026 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""DECODE compression module.""" + +# Implements the DECODE operator compression scheme described in the +# "TFLM DECODE Operator Design" document, revised May 20, 2025. +# +# The DECODE operator transforms an encoded tensor, alongside a paired +# ancillary data tensor, into a tensor ready for use as input to any +# operator. For example, an encoded tensor might contain compressed +# data, while the paired ancillary data tensor holds the information +# necessary for decompression. The DECODE operator's output is a fully +# decompressed tensor. +# +# DECODE operators are inserted into the TfLite model subgraph +# immediately before each operation that uses a decodable tensor as +# input. +# +# Ancillary Data Tensor +# +# The ancillary data tensor contains the information necessary for +# decoding. It begins with a 16-byte DECODE Common Metadata (DCM) +# header, followed by decode-type-specific ancillary data. +# +# DECODE Common Metadata (DCM) +# +# Byte 0: Decode type +# 0-127: TFLM-supported decode operations (see below) +# 128-255: Custom operations requiring application-registered +# handlers +# +# Supported decode types: +# +# 0: LUT decompression +# All TFLM tensor types supported in reference and optimized +# code. +# +# 1: Huffman decompression using Xtensa format decode tables +# INT8 and INT16 tensor types only, in reference and optimized +# code. +# +# 2: Pruning decompression +# All TFLM tensor types supported in reference and optimized +# code. +# +# 3-127: Reserved +# +# 128-255: Custom decode types +# Requires user-supplied encoding module and decoding ancillary +# data. +# +# Byte 1: DCM version (currently 1) +# +# Bytes 2-3: Reserved +# +# Bytes 4-15: User-defined +# Used by TFLM decode types to avoid requiring additional alignment +# of metadata or ancillary data. +# +# The 16-byte DCM size ensures that subsequent metadata and ancillary +# data are 128-bit aligned, which is required for some optimized +# decoding operations such as Xtensa LUT decompression. +# +# For TFLM decode types, ancillary data starts immediately after the +# DCM. For custom decode types, the location is determined by +# user-defined metadata. + +from dataclasses import dataclass +from typing import Protocol + + +class DecodeType: + """Decode operation type (0-255). + + Use predefined constants for built-in types or DecodeType.custom() + for custom types: + DecodeType.LUT # 0 + DecodeType.HUFFMAN # 1 + DecodeType.PRUNING # 2 + DecodeType.custom(200) # Custom type 128-255 + """ + + # Built-in decode types (class variables set after class definition) + LUT: 'DecodeType' + HUFFMAN: 'DecodeType' + PRUNING: 'DecodeType' + + def __init__(self, code: int, name: str = None): + """Initialize DecodeType. + + Args: + code: Integer code 0-255 + name: Optional name for the type. If not provided: + - Codes 0-127: Named "TYPE_{code}" + - Codes 128-255: Named "CUSTOM_{code}" + """ + if not 0 <= code <= 255: + raise ValueError(f"Decode type must be 0-255, got {code}") + self.code = code + + # Auto-generate name if not provided + if name is None: + self.name = f"CUSTOM_{code}" if code >= 128 else f"TYPE_{code}" + else: + self.name = name + + self._is_custom = code >= 128 + + @property + def is_custom(self) -> bool: + """True if this is a custom decode type (128-255).""" + return self._is_custom + + @classmethod + def custom(cls, code: int) -> 'DecodeType': + """Create custom decode type (128-255). + + Args: + code: Integer code 128-255 + + Returns: + DecodeType with name CUSTOM_{code} + """ + if not 128 <= code <= 255: + raise ValueError(f"Custom decode type must be 128-255, got {code}") + return cls(code) + + def __int__(self): + """Convert to integer for serialization.""" + return self.code + + def __eq__(self, other): + if isinstance(other, DecodeType): + return self.code == other.code + return self.code == other + + def __repr__(self): + return f"DecodeType.{self.name}({self.code})" + + +# Define built-in decode type constants +DecodeType.LUT = DecodeType(0, "LUT") +DecodeType.HUFFMAN = DecodeType(1, "HUFFMAN") +DecodeType.PRUNING = DecodeType(2, "PRUNING") + + +@dataclass +class DecodeCommonMetadata: + """16-byte DECODE Common Metadata (DCM) header. + + Attributes: + decode_type: Decode operation type. Use DecodeType constants or + DecodeType.custom(code) for custom types. + version: DCM version (currently 1). + user_data: 12 bytes of user-defined data (bytes 4-15 of DCM). Used by TFLM + decode types to avoid requiring additional alignment of metadata + or ancillary data. + """ + decode_type: DecodeType + version: int = 1 + user_data: bytes = b'\x00' * 12 + + def to_bytes(self) -> bytes: + """Serialize DCM to 16-byte sequence.""" + decode_code = int(self.decode_type) + if not 0 <= self.version <= 255: + raise ValueError(f"version must be 0-255, got {self.version}") + if len(self.user_data) < 12: + # Pad with zeros if user_data is too short + user_data = self.user_data + b'\x00' * (12 - len(self.user_data)) + else: + user_data = self.user_data[:12] + + result = bytearray(16) + result[0] = decode_code + result[1] = self.version + # bytes 2-3 remain zero (reserved) + result[4:16] = user_data + return bytes(result) + + +class AncillaryDataSerializer(Protocol): + """Protocol for objects that can serialize ancillary data.""" + + def to_bytes(self) -> bytes: + ... + + +@dataclass +class AncillaryDataTensor: + """Complete Ancillary Data Tensor (ADT): DCM + decode-type-specific data. + + The ADT is stored as a buffer in the TFLite model. It begins with a 16-byte + DCM header, followed by decode-type-specific ancillary data. + + Attributes: + dcm: The DECODE Common Metadata header. + ancillary_data: The decode-type-specific ancillary data, either as raw bytes + or as an object implementing the AncillaryDataSerializer + protocol. May be None if only the DCM is needed. + """ + dcm: DecodeCommonMetadata + ancillary_data: AncillaryDataSerializer | bytes | None = None + + def with_ancillary_data( + self, data: AncillaryDataSerializer | bytes) -> 'AncillaryDataTensor': + """Create new ADT with ancillary data added. + + Args: + data: Ancillary data to add, either as raw bytes or as an object + implementing AncillaryDataSerializer. + + Returns: + New AncillaryDataTensor with the specified ancillary data. + """ + return AncillaryDataTensor(self.dcm, data) + + def to_bytes(self) -> bytes: + """Serialize entire ADT to bytes. + + Returns: + Byte sequence containing DCM followed by ancillary data (if present). + """ + dcm_bytes = self.dcm.to_bytes() + if self.ancillary_data is None: + return dcm_bytes + if isinstance(self.ancillary_data, bytes): + return dcm_bytes + self.ancillary_data + return dcm_bytes + self.ancillary_data.to_bytes() diff --git a/tensorflow/lite/micro/compression/decode_test.py b/tensorflow/lite/micro/compression/decode_test.py new file mode 100644 index 00000000000..eca3b42b2b4 --- /dev/null +++ b/tensorflow/lite/micro/compression/decode_test.py @@ -0,0 +1,155 @@ +# Copyright 2026 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +from tflite_micro.tensorflow.lite.micro.compression import decode + + +class TestDecodeCommonMetadata(unittest.TestCase): + + def testBasicSerialization(self): + dcm = decode.DecodeCommonMetadata(decode_type=decode.DecodeType.LUT) + result = dcm.to_bytes() + + # Should be exactly 16 bytes + self.assertEqual(len(result), 16) + + # Byte 0: decode_type + self.assertEqual(result[0], 0) + + # Byte 1: version (default 1) + self.assertEqual(result[1], 1) + + # Bytes 2-3: reserved (should be zero) + self.assertEqual(result[2], 0) + self.assertEqual(result[3], 0) + + # Bytes 4-15: user_data (default all zeros) + self.assertEqual(result[4:16], b'\x00' * 12) + + def testCustomVersion(self): + dcm = decode.DecodeCommonMetadata(decode_type=1, version=2) + result = dcm.to_bytes() + + self.assertEqual(result[0], 1) + self.assertEqual(result[1], 2) + + def testUserData(self): + user_data = b'\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c' + dcm = decode.DecodeCommonMetadata(decode_type=0, user_data=user_data) + result = dcm.to_bytes() + + self.assertEqual(result[4:16], user_data) + + def testUserDataPadding(self): + # User data shorter than 12 bytes should be padded with zeros + user_data = b'\x01\x02\x03' + dcm = decode.DecodeCommonMetadata(decode_type=0, user_data=user_data) + result = dcm.to_bytes() + + expected = b'\x01\x02\x03' + b'\x00' * 9 + self.assertEqual(result[4:16], expected) + + def testUserDataTruncation(self): + # User data longer than 12 bytes should be truncated + user_data = b'\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x0f' + dcm = decode.DecodeCommonMetadata(decode_type=0, user_data=user_data) + result = dcm.to_bytes() + + self.assertEqual(result[4:16], user_data[:12]) + + def testDecodeTypeRange(self): + # Valid decode types: 0-255 + decode.DecodeCommonMetadata(decode_type=decode.DecodeType.LUT).to_bytes() + decode.DecodeCommonMetadata(decode_type=decode.DecodeType(127)).to_bytes() + decode.DecodeCommonMetadata( + decode_type=decode.DecodeType.custom(255)).to_bytes() + + # Invalid decode types should raise ValueError + with self.assertRaises(ValueError): + decode.DecodeCommonMetadata(decode_type=decode.DecodeType(-1)).to_bytes() + with self.assertRaises(ValueError): + decode.DecodeCommonMetadata( + decode_type=decode.DecodeType(256)).to_bytes() + + def testVersionRange(self): + # Valid versions: 0-255 + decode.DecodeCommonMetadata(decode_type=0, version=0).to_bytes() + decode.DecodeCommonMetadata(decode_type=0, version=255).to_bytes() + + # Invalid versions should raise ValueError + with self.assertRaises(ValueError): + decode.DecodeCommonMetadata(decode_type=0, version=-1).to_bytes() + with self.assertRaises(ValueError): + decode.DecodeCommonMetadata(decode_type=0, version=256).to_bytes() + + +class TestAncillaryDataTensor(unittest.TestCase): + + def testDcmOnly(self): + dcm = decode.DecodeCommonMetadata(decode_type=decode.DecodeType.LUT) + adt = decode.AncillaryDataTensor(dcm) + result = adt.to_bytes() + + # Should be just the 16-byte DCM + self.assertEqual(len(result), 16) + self.assertEqual(result, dcm.to_bytes()) + + def testWithBytesAncillaryData(self): + dcm = decode.DecodeCommonMetadata(decode_type=decode.DecodeType.HUFFMAN) + ancillary = b'\xaa\xbb\xcc\xdd' + adt = decode.AncillaryDataTensor(dcm, ancillary) + result = adt.to_bytes() + + # Should be DCM + ancillary data + self.assertEqual(len(result), 20) + self.assertEqual(result[:16], dcm.to_bytes()) + self.assertEqual(result[16:], ancillary) + + def testWithAncillaryDataMethod(self): + dcm = decode.DecodeCommonMetadata(decode_type=decode.DecodeType.PRUNING) + adt = decode.AncillaryDataTensor(dcm) + + ancillary = b'\x11\x22\x33\x44' + adt_with_data = adt.with_ancillary_data(ancillary) + result = adt_with_data.to_bytes() + + # Original ADT should be unchanged + self.assertEqual(adt.to_bytes(), dcm.to_bytes()) + + # New ADT should have ancillary data + self.assertEqual(len(result), 20) + self.assertEqual(result[:16], dcm.to_bytes()) + self.assertEqual(result[16:], ancillary) + + def testWithSerializerProtocol(self): + # Test with an object that implements AncillaryDataSerializer + class MockSerializer: + + def to_bytes(self): + return b'\xff\xee\xdd\xcc' + + dcm = decode.DecodeCommonMetadata(decode_type=decode.DecodeType(3)) + serializer = MockSerializer() + adt = decode.AncillaryDataTensor(dcm, serializer) + result = adt.to_bytes() + + self.assertEqual(len(result), 20) + self.assertEqual(result[:16], dcm.to_bytes()) + self.assertEqual(result[16:], b'\xff\xee\xdd\xcc') + + +if __name__ == '__main__': + unittest.main() From fa2a3a83836b13b485baeea936cbcb48a519947b Mon Sep 17 00:00:00 2001 From: Ryan Kuester Date: Sun, 24 May 2026 23:20:50 -0500 Subject: [PATCH 05/19] feat(compression): add Compressor protocol Define the plugin interface for compression methods. Each compressor implements the Compressor protocol with a compress() method that returns encoded data and ancillary data. BUG=part of #3256 --- tensorflow/lite/micro/compression/BUILD | 10 +++ .../lite/micro/compression/compressor.py | 80 +++++++++++++++++++ 2 files changed, 90 insertions(+) create mode 100644 tensorflow/lite/micro/compression/compressor.py diff --git a/tensorflow/lite/micro/compression/BUILD b/tensorflow/lite/micro/compression/BUILD index 8c22f9d9c71..a4e45249b90 100644 --- a/tensorflow/lite/micro/compression/BUILD +++ b/tensorflow/lite/micro/compression/BUILD @@ -261,6 +261,16 @@ tflm_py_test( ], ) +tflm_py_library( + name = "compressor", + srcs = ["compressor.py"], + deps = [ + ":decode", + ":model_editor", + ":spec", + ], +) + tflm_py_binary( name = "view", srcs = [ diff --git a/tensorflow/lite/micro/compression/compressor.py b/tensorflow/lite/micro/compression/compressor.py new file mode 100644 index 00000000000..3d5a635eb09 --- /dev/null +++ b/tensorflow/lite/micro/compression/compressor.py @@ -0,0 +1,80 @@ +# Copyright 2026 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Compression plugin interface.""" + +from dataclasses import dataclass +from typing import Protocol + +from tflite_micro.tensorflow.lite.micro.compression import decode +from tflite_micro.tensorflow.lite.micro.compression import model_editor +from tflite_micro.tensorflow.lite.micro.compression import spec + + +class CompressionError(Exception): + """Raised when compression fails for the reason documented in the message.""" + + def __init__(self, message, wrapped_exception=None): + if wrapped_exception: + super().__init__(f"{message}: {str(wrapped_exception)}") + else: + super().__init__(message) + self.original_exception = wrapped_exception + + +@dataclass +class CompressionResult: + """Result of compressing a tensor. + + Attributes: + encoded_data: The compressed tensor data (e.g., packed indices for LUT). + ancillary_data: The complete ancillary data tensor bytes (DCM + type-specific + data). This is the full buffer contents for the ancillary + tensor. + """ + encoded_data: bytes + ancillary_data: bytes + + +class Compressor(Protocol): + """Protocol that compression plugins must implement. + + Each compression method (LUT, Huffman, Pruning) provides a class implementing + this protocol. The compress() function uses duck typing to call the plugin. + """ + + @property + def decode_type(self) -> decode.DecodeType: + """The DecodeType constant for this compression method.""" + ... + + def compress( + self, + tensor: model_editor.Tensor, + method: spec.CompressionMethod, + ) -> CompressionResult: + """Compress a tensor according to the specified method. + + Args: + tensor: The tensor to compress. Must have data (tensor.array is not None) + and quantization parameters for axis inference. + method: The compression method spec (e.g., LookUpTableCompression). + + Returns: + CompressionResult with encoded tensor data and ancillary data bytes. + + Raises: + CompressionError: If compression fails (e.g., too many unique values + for specified bitwidth, missing quantization, etc.). + """ + ... From 5c3e4ba002190d8ffe897b884234bf4f9ff6b081 Mon Sep 17 00:00:00 2001 From: Ryan Kuester Date: Sun, 24 May 2026 23:22:36 -0500 Subject: [PATCH 06/19] feat(compression): add LUT compression plugin Implement LutCompressor using the Compressor protocol. Lookup table compression replaces tensor values with indices into a table of unique values, producing packed indices and ancillary data in the format expected by the TFLM DECODE kernel. Supports per-tensor and per-channel compression, sizes value tables to actual unique count, and handles unquantized tensors. BUG=part of #3256 --- tensorflow/lite/micro/compression/BUILD | 33 ++ tensorflow/lite/micro/compression/lut.py | 318 ++++++++++++++ tensorflow/lite/micro/compression/lut_test.py | 405 ++++++++++++++++++ 3 files changed, 756 insertions(+) create mode 100644 tensorflow/lite/micro/compression/lut.py create mode 100644 tensorflow/lite/micro/compression/lut_test.py diff --git a/tensorflow/lite/micro/compression/BUILD b/tensorflow/lite/micro/compression/BUILD index a4e45249b90..f7f411f196f 100644 --- a/tensorflow/lite/micro/compression/BUILD +++ b/tensorflow/lite/micro/compression/BUILD @@ -271,6 +271,39 @@ tflm_py_library( ], ) +tflm_py_library( + name = "lut", + srcs = ["lut.py"], + deps = [ + ":compressor", + ":decode", + ":model_editor", + ":spec", + requirement("bitarray"), + requirement("numpy"), + ], +) + +tflm_py_test( + name = "lut_test", + size = "small", + srcs = ["lut_test.py"], + tags = [ + "noasan", + "nomsan", + "noubsan", + ], + deps = [ + ":compressor", + ":decode", + ":lut", + ":model_editor", + ":spec", + "//tensorflow/lite/python:schema_py", + requirement("numpy"), + ], +) + tflm_py_binary( name = "view", srcs = [ diff --git a/tensorflow/lite/micro/compression/lut.py b/tensorflow/lite/micro/compression/lut.py new file mode 100644 index 00000000000..def34059ac5 --- /dev/null +++ b/tensorflow/lite/micro/compression/lut.py @@ -0,0 +1,318 @@ +# Copyright 2026 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""LUT (Look-Up Table) compression plugin.""" + +import sys +from dataclasses import dataclass, field +from typing import Optional + +import bitarray +import bitarray.util +import numpy as np + +from tflite_micro.tensorflow.lite.micro.compression import compressor +from tflite_micro.tensorflow.lite.micro.compression import decode +from tflite_micro.tensorflow.lite.micro.compression import model_editor +from tflite_micro.tensorflow.lite.micro.compression import spec + + +@dataclass +class LutCompressedArray: + """Intermediate representation of LUT-compressed data. + + Attributes: + compression_axis: The axis along which compression was performed, or None + for per-tensor compression. + lookup_tables: List of value lookup tables. One table for per-tensor + compression, or one per channel for per-channel compression. + indices: Array of indices into the lookup tables, same shape as original. + """ + compression_axis: Optional[int] = None + lookup_tables: list[np.ndarray] = field(default_factory=list) + indices: np.ndarray = field(default_factory=lambda: np.array([])) + + @property + def index_bitwidth(self) -> int: + """Returns the number of bits required to encode the indices.""" + if self.indices is None or self.indices.size == 0: + raise ValueError("No indices to compute bitwidth from") + max_index = int(np.max(self.indices)) + return max_index.bit_length() or 1 + + +@dataclass +class LutAncillaryData: + """LUT-specific ancillary data matching C++ decode_state_lut.cc format. + + The LUT ancillary data uses the DCM user_data bytes (4-15) plus value tables: + - Byte 4: LUT version (currently 1) + - Byte 5: Params (lower 3 bits = bitwidth, 1-7) + - Byte 6: Value table channel stride (elements per channel) + - Bytes 7-15: Reserved (zeros) + - Bytes 16+: Value tables (concatenated, stride elements per channel) + + Attributes: + lut_version: LUT format version (currently 1). + bitwidth: Number of bits per index (1-7). + value_table_stride: Number of elements per channel in value tables. + value_tables: Packed value table data following the DCM. + """ + lut_version: int = 1 + bitwidth: int = 4 + value_table_stride: int = 16 + value_tables: bytes = b'' + + def __post_init__(self): + if not 1 <= self.bitwidth <= 7: + raise ValueError(f"bitwidth must be 1-7, got {self.bitwidth}") + if not 0 <= self.value_table_stride <= 128: + raise ValueError( + f"value_table_stride must be 0-128, got {self.value_table_stride}") + + def to_user_data(self) -> bytes: + """Serialize to 12-byte user_data for DCM bytes 4-15.""" + user_data = bytearray(12) + user_data[0] = self.lut_version + user_data[1] = self.bitwidth & 0x07 + user_data[2] = self.value_table_stride + # Bytes 3-11 (DCM bytes 7-15) remain zero (reserved) + return bytes(user_data) + + def to_bytes(self) -> bytes: + """Serialize for use as AncillaryDataTensor.ancillary_data.""" + # This returns the type-specific data that follows the DCM header. + # For LUT, that's just the value tables since user_data is in DCM. + return self.value_tables + + +def compress_array(tensor: np.ndarray, + axis: Optional[int]) -> LutCompressedArray: + """Compresses the given tensor using lookup tables. + + Args: + tensor: The tensor to be compressed. + axis: The axis along which to compress. If an axis is given, a lookup table + is created for each slice along the axis. If axis is None, a single + lookup table is used for the entire tensor. + + Compressing a tensor with a lookup table per slice along a particular + axis is analogous to quantizing a tensor with different quantization + parameters per slice along a particular axis (dimension). + + Returns: + LutCompressedArray containing lookup tables and indices. + """ + compressed = LutCompressedArray() + compressed.compression_axis = axis + + if axis is None: + # Compute unique values and indices for the entire tensor + values, indices = np.unique(tensor, return_inverse=True) + compressed.lookup_tables.append(values) + compressed.indices = indices.reshape(tensor.shape) + else: + # Iterate over slices along the compression axis + slice_indices = [] + for slice in np.moveaxis(tensor, axis, 0): + values, indices = np.unique(slice, return_inverse=True) + compressed.lookup_tables.append(values) + indices = indices.reshape(slice.shape) + slice_indices.append(indices) + + # Reconstruct a tensor of indices from the slices + stacked = np.stack(slice_indices, axis=0) + compressed.indices = np.moveaxis(stacked, 0, axis) + + return compressed + + +def identify_compression_axis(tensor: model_editor.Tensor) -> Optional[int]: + """Determines the axis along which to compress. + + The axis along which to compress is inferred from the tensor's quantization + parameters. Unquantized tensors use per-tensor compression. + + Args: + tensor: The tensor to analyze. + + Returns: + The axis along which to compress, or None to indicate one value table for + the entire tensor. + + Raises: + CompressionError: If the axis cannot be determined from quantization. + """ + q = tensor.quantization + if q is None: + return None + + # model_editor wraps quantization, access scales/axis from wrapper + scales = q.scales if isinstance(q.scales, list) else [q.scales] + quantization_channels = len(scales) + + if quantization_channels == 1: + return None + + if q.axis is not None and q.axis < len(tensor.shape): + if quantization_channels == tensor.shape[q.axis]: + return q.axis + + raise compressor.CompressionError( + "Invalid or no quantization parameters from which to " + "infer the axis along which tensor should be compressed.") + + +def check_bitwidth(compressed: int, specified: int, tensor_spec: spec.Tensor): + """Validates that the specified bitwidth is sufficient. + + It is an error if the bitwidth required to compress a tensor exceeds the + specified bitwith, and a warning if the tensor can be compressed in less than + the specified bitwidth. The latter is allowed, and is not an error, to permit + testing with larger bitwidths without re-binning a model. + + Args: + compressed: The bitwidth required by the compressed data. + specified: The bitwidth specified in the compression spec. + tensor_spec: The tensor spec, for error messages. + + Raises: + CompressionError: If specified bitwidth is too small. + """ + if compressed > specified: + raise compressor.CompressionError( + f"index_bitwidth too small: {compressed} bits needed to " + f"enumerate unique values in tensor specified in {tensor_spec}") + elif compressed < specified: + print( + f"warning: index_bitwidth too large: only {compressed} " + f"bits needed to enumerate unique values in tensor specified in " + f"{tensor_spec}", + file=sys.stderr) + + +def pack_indices(indices: np.ndarray, bitwidth: int) -> bytes: + """Packs indices into a bytearray using bitwidth-sized fields. + + Args: + indices: Array of indices to pack. + bitwidth: Number of bits per index. + + Returns: + Packed bytes with indices in big-endian bit order. + """ + endianness = "big" + bits = bitarray.bitarray(endian=endianness) + for i in indices.ravel(): + bits.extend( + bitarray.util.int2ba(int(i), length=bitwidth, endian=endianness)) + return bits.tobytes() + + +def pack_lookup_tables(tables: list[np.ndarray], table_len: int) -> bytes: + """Packs the value tables of a LutCompressedArray. + + Pack the value tables of a LutCompressedArray into a bytes object in the + format writable to a value_table buffer in the .tflite flatbuffer. The + tables are concatenated. + + Args: + tables: List of numpy arrays containing lookup table values. + table_len: Length to pad each table to (typically 2**bitwidth). + + Returns: + Packed bytes containing all tables concatenated. + """ + buffer = bytearray() + for t in tables: + padding_needed = table_len - len(t) + padded = np.pad(t, (0, padding_needed), mode='constant', constant_values=0) + buffer.extend(padded.tobytes()) + return bytes(buffer) + + +class LutCompressor: + """LUT compression plugin implementing the Compressor protocol.""" + + @property + def decode_type(self) -> decode.DecodeType: + """Returns DecodeType.LUT.""" + return decode.DecodeType.LUT + + def compress( + self, + tensor: model_editor.Tensor, + method: spec.CompressionMethod, + ) -> compressor.CompressionResult: + """Compress a tensor using LUT compression. + + Args: + tensor: The tensor to compress. + method: Must be a LookUpTableCompression instance. + + Returns: + CompressionResult with packed indices and ancillary data. + + Raises: + CompressionError: If compression fails. + """ + if not isinstance(method, spec.LookUpTableCompression): + raise compressor.CompressionError( + f"LutCompressor requires LookUpTableCompression, got {type(method)}") + + if tensor.array is None: + raise compressor.CompressionError("Tensor has no data to compress") + + spec_bitwidth = method.index_bitwidth + axis = identify_compression_axis(tensor) + compressed = compress_array(tensor.array, axis) + # Note: check_bitwidth requires a spec.Tensor but we don't have it here. + # We'll do a simpler check. + actual_bitwidth = compressed.index_bitwidth + if actual_bitwidth > spec_bitwidth: + raise compressor.CompressionError( + f"index_bitwidth too small: {actual_bitwidth} bits needed, " + f"but only {spec_bitwidth} specified") + elif actual_bitwidth < spec_bitwidth: + print( + f"warning: index_bitwidth larger than necessary: only " + f"{actual_bitwidth} bits needed, but {spec_bitwidth} specified", + file=sys.stderr) + + # Pack indices into bytes + encoded_data = pack_indices(compressed.indices, spec_bitwidth) + + # Pack value tables + table_len = max(len(t) for t in compressed.lookup_tables) + value_tables_bytes = pack_lookup_tables(compressed.lookup_tables, + table_len) + + # Build ancillary data + lut_data = LutAncillaryData( + lut_version=1, + bitwidth=spec_bitwidth, + value_table_stride=table_len, + value_tables=value_tables_bytes, + ) + + # Build complete ancillary data tensor bytes: DCM header + value tables + dcm = decode.DecodeCommonMetadata( + decode_type=self.decode_type, + user_data=lut_data.to_user_data(), + ) + ancillary_data = dcm.to_bytes() + lut_data.to_bytes() + + return compressor.CompressionResult( + encoded_data=encoded_data, + ancillary_data=ancillary_data, + ) diff --git a/tensorflow/lite/micro/compression/lut_test.py b/tensorflow/lite/micro/compression/lut_test.py new file mode 100644 index 00000000000..d01dcfd4260 --- /dev/null +++ b/tensorflow/lite/micro/compression/lut_test.py @@ -0,0 +1,405 @@ +# Copyright 2026 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Unit tests for LUT compression plugin.""" + +import numpy as np +import unittest + +from tflite_micro.tensorflow.lite.micro.compression import compressor +from tflite_micro.tensorflow.lite.micro.compression import decode +from tflite_micro.tensorflow.lite.micro.compression import lut +from tflite_micro.tensorflow.lite.micro.compression import model_editor +from tflite_micro.tensorflow.lite.micro.compression import spec +from tflite_micro.tensorflow.lite.python import schema_py_generated as tflite + + +class TestCompressArray(unittest.TestCase): + """Tests for the compress_array function.""" + + def test_per_tensor_basic(self): + """Per-tensor compression extracts unique values.""" + array = np.array([1, 2, 1, 2, 3, 3], dtype=np.int8) + compressed = lut.compress_array(array, axis=None) + + self.assertIsNone(compressed.compression_axis) + self.assertEqual(len(compressed.lookup_tables), 1) + np.testing.assert_array_equal(compressed.lookup_tables[0], [1, 2, 3]) + # Indices should map back to original values + reconstructed = compressed.lookup_tables[0][compressed.indices] + np.testing.assert_array_equal(reconstructed, array) + + def test_per_tensor_preserves_shape(self): + """Indices array has same shape as input.""" + # yapf: disable + array = np.array([[1, 2], + [3, 1], + [2, 3]], dtype=np.int8) + # yapf: enable + compressed = lut.compress_array(array, axis=None) + + self.assertEqual(compressed.indices.shape, array.shape) + + def test_per_channel_axis0(self): + """Per-channel compression along axis 0.""" + # Each row gets its own value table + # yapf: disable + array = np.array([[1, 1, 1], + [5, 5, 5], + [9, 9, 9]], dtype=np.int8) + # yapf: enable + compressed = lut.compress_array(array, axis=0) + + self.assertEqual(compressed.compression_axis, 0) + self.assertEqual(len(compressed.lookup_tables), 3) + np.testing.assert_array_equal(compressed.lookup_tables[0], [1]) + np.testing.assert_array_equal(compressed.lookup_tables[1], [5]) + np.testing.assert_array_equal(compressed.lookup_tables[2], [9]) + + def test_per_channel_axis1(self): + """Per-channel compression along axis 1.""" + # Each column gets its own value table + # yapf: disable + array = np.array([[1, 5], + [1, 5], + [1, 5]], dtype=np.int8) + # yapf: enable + compressed = lut.compress_array(array, axis=1) + + self.assertEqual(compressed.compression_axis, 1) + self.assertEqual(len(compressed.lookup_tables), 2) + np.testing.assert_array_equal(compressed.lookup_tables[0], [1]) + np.testing.assert_array_equal(compressed.lookup_tables[1], [5]) + + def test_single_value(self): + """Array with single unique value.""" + array = np.array([7, 7, 7, 7], dtype=np.int8) + compressed = lut.compress_array(array, axis=None) + + self.assertEqual(len(compressed.lookup_tables), 1) + np.testing.assert_array_equal(compressed.lookup_tables[0], [7]) + np.testing.assert_array_equal(compressed.indices, [0, 0, 0, 0]) + + def test_bitwidth_calculation(self): + """Index bitwidth is computed correctly.""" + # 3 unique values -> 2 bits needed + array = np.array([0, 1, 2], dtype=np.int8) + compressed = lut.compress_array(array, axis=None) + self.assertEqual(compressed.index_bitwidth, 2) + + # 4 unique values -> 2 bits needed + array = np.array([0, 1, 2, 3], dtype=np.int8) + compressed = lut.compress_array(array, axis=None) + self.assertEqual(compressed.index_bitwidth, 2) + + # 5 unique values -> 3 bits needed + array = np.array([0, 1, 2, 3, 4], dtype=np.int8) + compressed = lut.compress_array(array, axis=None) + self.assertEqual(compressed.index_bitwidth, 3) + + def test_bitwidth_single_value(self): + """Single unique value requires 1 bit.""" + array = np.array([42, 42, 42], dtype=np.int8) + compressed = lut.compress_array(array, axis=None) + self.assertEqual(compressed.index_bitwidth, 1) + + +class TestPackIndices(unittest.TestCase): + """Tests for the pack_indices function.""" + + def test_4bit_packing(self): + """Pack indices into 4-bit fields.""" + indices = np.array([1, 2, 3, 0]) + result = lut.pack_indices(indices, bitwidth=4) + # Big-endian: 0001 0010 | 0011 0000 = 0x12 0x30 + self.assertEqual(result, bytes([0x12, 0x30])) + + def test_2bit_packing(self): + """Pack indices into 2-bit fields.""" + indices = np.array([0, 1, 2, 3]) + result = lut.pack_indices(indices, bitwidth=2) + # Big-endian: 00 01 10 11 = 0x1B + self.assertEqual(result, bytes([0x1B])) + + def test_3bit_packing(self): + """Pack indices into 3-bit fields.""" + indices = np.array([0, 1, 2, 3, 4, 5, 6, 7]) + result = lut.pack_indices(indices, bitwidth=3) + # 000 001 010 011 | 100 101 110 111 + # 00000101 | 00111001 | 01110111 = 0x05 0x39 0x77 + self.assertEqual(result, bytes([0x05, 0x39, 0x77])) + + def test_1bit_packing(self): + """Pack indices into 1-bit fields.""" + indices = np.array([0, 1, 0, 1, 1, 0, 1, 0]) + result = lut.pack_indices(indices, bitwidth=1) + # 0 1 0 1 1 0 1 0 = 0x5A + self.assertEqual(result, bytes([0x5A])) + + def test_multidimensional_flattens(self): + """Multidimensional indices are flattened row-major.""" + # yapf: disable + indices = np.array([[0, 1], + [2, 3]]) + # yapf: enable + result = lut.pack_indices(indices, bitwidth=4) + # 0000 0001 | 0010 0011 = 0x01 0x23 + self.assertEqual(result, bytes([0x01, 0x23])) + + +class TestPackLookupTables(unittest.TestCase): + """Tests for the pack_lookup_tables function.""" + + def test_single_table_int8(self): + """Pack single INT8 lookup table.""" + tables = [np.array([10, 20, 30], dtype=np.int8)] + result = lut.pack_lookup_tables(tables, table_len=4) + # Values: 10, 20, 30, 0 (padding) + self.assertEqual(result, bytes([10, 20, 30, 0])) + + def test_multiple_tables(self): + """Pack multiple lookup tables.""" + tables = [ + np.array([1, 2], dtype=np.int8), + np.array([3, 4], dtype=np.int8), + ] + result = lut.pack_lookup_tables(tables, table_len=4) + # Table 1: 1, 2, 0, 0 | Table 2: 3, 4, 0, 0 + self.assertEqual(result, bytes([1, 2, 0, 0, 3, 4, 0, 0])) + + def test_int16_little_endian(self): + """INT16 values are packed in native byte order.""" + tables = [np.array([0x1234, 0x5678], dtype=' Date: Sun, 24 May 2026 23:24:04 -0500 Subject: [PATCH 07/19] feat(compression): add Huffman and Pruning compression support Add spec types, YAML parser support, and plugin stubs for Huffman and Pruning compression methods. The plugins raise CompressionError when invoked, to be replaced with working implementations later. BUG=part of #3256 --- tensorflow/lite/micro/compression/BUILD | 22 +++++++ tensorflow/lite/micro/compression/huffman.py | 60 ++++++++++++++++++++ tensorflow/lite/micro/compression/pruning.py | 59 +++++++++++++++++++ tensorflow/lite/micro/compression/spec.py | 51 +++++++++++++++-- 4 files changed, 186 insertions(+), 6 deletions(-) create mode 100644 tensorflow/lite/micro/compression/huffman.py create mode 100644 tensorflow/lite/micro/compression/pruning.py diff --git a/tensorflow/lite/micro/compression/BUILD b/tensorflow/lite/micro/compression/BUILD index f7f411f196f..375a42d7a49 100644 --- a/tensorflow/lite/micro/compression/BUILD +++ b/tensorflow/lite/micro/compression/BUILD @@ -304,6 +304,28 @@ tflm_py_test( ], ) +tflm_py_library( + name = "huffman", + srcs = ["huffman.py"], + deps = [ + ":compressor", + ":decode", + ":model_editor", + ":spec", + ], +) + +tflm_py_library( + name = "pruning", + srcs = ["pruning.py"], + deps = [ + ":compressor", + ":decode", + ":model_editor", + ":spec", + ], +) + tflm_py_binary( name = "view", srcs = [ diff --git a/tensorflow/lite/micro/compression/huffman.py b/tensorflow/lite/micro/compression/huffman.py new file mode 100644 index 00000000000..40d0be9284a --- /dev/null +++ b/tensorflow/lite/micro/compression/huffman.py @@ -0,0 +1,60 @@ +# Copyright 2026 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Huffman compression plugin (stub). + +This module provides a placeholder for Huffman compression using Xtensa-format +decode tables. The actual implementation is not yet available. + +Supported tensor types (when implemented): INT8, INT16 +""" + +from tflite_micro.tensorflow.lite.micro.compression import compressor +from tflite_micro.tensorflow.lite.micro.compression import decode +from tflite_micro.tensorflow.lite.micro.compression import model_editor +from tflite_micro.tensorflow.lite.micro.compression import spec + + +class HuffmanCompressor: + """Huffman compression plugin (stub). + + This stub exists to validate the plugin architecture. The actual Huffman + compression algorithm using Xtensa-format decode tables is not yet + implemented. + """ + + @property + def decode_type(self) -> decode.DecodeType: + """Returns DecodeType.HUFFMAN.""" + return decode.DecodeType.HUFFMAN + + def compress( + self, + tensor: model_editor.Tensor, + method: spec.CompressionMethod, + ) -> compressor.CompressionResult: + """Compress a tensor using Huffman encoding. + + Args: + tensor: The tensor to compress. + method: Must be a HuffmanCompression instance. + + Returns: + CompressionResult (not implemented). + + Raises: + CompressionError: Always, since this is a stub. + """ + raise compressor.CompressionError( + "Huffman compression not yet implemented. " + "This stub exists to validate the plugin architecture.") diff --git a/tensorflow/lite/micro/compression/pruning.py b/tensorflow/lite/micro/compression/pruning.py new file mode 100644 index 00000000000..2181b73e34a --- /dev/null +++ b/tensorflow/lite/micro/compression/pruning.py @@ -0,0 +1,59 @@ +# Copyright 2026 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Pruning compression plugin (stub). + +This module provides a placeholder for pruning (sparsity) compression. +The actual implementation is not yet available. + +Supported tensor types (when implemented): All TFLM tensor types +""" + +from tflite_micro.tensorflow.lite.micro.compression import compressor +from tflite_micro.tensorflow.lite.micro.compression import decode +from tflite_micro.tensorflow.lite.micro.compression import model_editor +from tflite_micro.tensorflow.lite.micro.compression import spec + + +class PruningCompressor: + """Pruning compression plugin (stub). + + This stub exists to validate the plugin architecture. The actual pruning + compression algorithm for sparse tensors is not yet implemented. + """ + + @property + def decode_type(self) -> decode.DecodeType: + """Returns DecodeType.PRUNING.""" + return decode.DecodeType.PRUNING + + def compress( + self, + tensor: model_editor.Tensor, + method: spec.CompressionMethod, + ) -> compressor.CompressionResult: + """Compress a tensor using pruning (sparsity) encoding. + + Args: + tensor: The tensor to compress. + method: Must be a PruningCompression instance. + + Returns: + CompressionResult (not implemented). + + Raises: + CompressionError: Always, since this is a stub. + """ + raise compressor.CompressionError( + "Pruning compression not yet implemented. " + "This stub exists to validate the plugin architecture.") diff --git a/tensorflow/lite/micro/compression/spec.py b/tensorflow/lite/micro/compression/spec.py index 6f782e92d7a..5c0f81885bc 100644 --- a/tensorflow/lite/micro/compression/spec.py +++ b/tensorflow/lite/micro/compression/spec.py @@ -58,10 +58,32 @@ class Tensor: @dataclass class LookUpTableCompression(CompressionMethod): + """LUT compression using lookup tables. + Attributes: + index_bitwidth: Number of bits per index (1-7). + """ index_bitwidth: int +@dataclass +class HuffmanCompression(CompressionMethod): + """Huffman compression using Xtensa-format decode tables. + + Supported tensor types: INT8, INT16 only. + """ + pass + + +@dataclass +class PruningCompression(CompressionMethod): + """Pruning (sparsity) compression. + + Supported tensor types: All TFLM tensor types. + """ + pass + + class ParseError(Exception): "Raised when the spec string cannot be parsed." @@ -70,6 +92,18 @@ def __init__(self, message="error parsing spec", wrapped_exception=None): self.original_exception = wrapped_exception +def _parse_compression_method(comp: dict) -> CompressionMethod: + """Parse a single compression method from YAML dict.""" + if "lut" in comp: + return LookUpTableCompression(index_bitwidth=comp["lut"]["index_bitwidth"]) + elif "huffman" in comp: + return HuffmanCompression() + elif "pruning" in comp: + return PruningCompression() + else: + raise ParseError(f"Unknown compression method: {list(comp.keys())}") + + def parse_yaml(y: str) -> list[Tensor]: "Parses a compression spec in a YAML string into its Python representation." try: @@ -77,14 +111,19 @@ def parse_yaml(y: str) -> list[Tensor]: tensors = [] for item in config["tensors"]: - bitwidth = item["compression"][0]["lut"]["index_bitwidth"] - tensor = Tensor(subgraph=item["subgraph"], - tensor=item["tensor"], - compression=[ - LookUpTableCompression(index_bitwidth=bitwidth), - ]) + methods = [] + for comp in item["compression"]: + methods.append(_parse_compression_method(comp)) + + tensor = Tensor( + subgraph=item["subgraph"], + tensor=item["tensor"], + compression=methods, + ) tensors.append(tensor) + except ParseError: + raise except Exception as e: raise ParseError() from e From 6a81b9844658eb5a532b14aab1d4c64d575dfd7e Mon Sep 17 00:00:00 2001 From: Ryan Kuester Date: Sun, 24 May 2026 23:26:33 -0500 Subject: [PATCH 08/19] feat(python): add alt decompression memory parameter to interpreter Add alt_decompression_memory_size parameter to the Python interpreter API. When non-zero, allocates a separate memory region for DECODE operator outputs and calls SetDecompressionMemory before AllocateTensors. BUG=part of #3256 --- python/tflite_micro/_runtime.cc | 9 +++++---- python/tflite_micro/interpreter_wrapper.cc | 18 ++++++++++++++++-- python/tflite_micro/interpreter_wrapper.h | 6 +++++- python/tflite_micro/python_ops_resolver.cc | 2 ++ python/tflite_micro/runtime.py | 12 ++++++++++++ 5 files changed, 40 insertions(+), 7 deletions(-) diff --git a/python/tflite_micro/_runtime.cc b/python/tflite_micro/_runtime.cc index 246545fd016..53825f14f0d 100644 --- a/python/tflite_micro/_runtime.cc +++ b/python/tflite_micro/_runtime.cc @@ -33,10 +33,11 @@ PYBIND11_MODULE(_runtime, m) { .def(py::init([](const py::bytes& data, const std::vector& registerers_by_name, size_t arena_size, int num_resource_variables, - tflite::InterpreterConfig config) { - return std::unique_ptr( - new InterpreterWrapper(data.ptr(), registerers_by_name, arena_size, - num_resource_variables, config)); + tflite::InterpreterConfig config, + size_t alt_decompression_memory_size) { + return std::unique_ptr(new InterpreterWrapper( + data.ptr(), registerers_by_name, arena_size, num_resource_variables, + config, alt_decompression_memory_size)); })) .def("PrintAllocations", &InterpreterWrapper::PrintAllocations) .def("Invoke", &InterpreterWrapper::Invoke) diff --git a/python/tflite_micro/interpreter_wrapper.cc b/python/tflite_micro/interpreter_wrapper.cc index 669589890ad..c74ab84736b 100644 --- a/python/tflite_micro/interpreter_wrapper.cc +++ b/python/tflite_micro/interpreter_wrapper.cc @@ -238,7 +238,14 @@ InterpreterWrapper::~InterpreterWrapper() { InterpreterWrapper::InterpreterWrapper( PyObject* model_data, const std::vector& registerers_by_name, - size_t arena_size, int num_resource_variables, InterpreterConfig config) { + size_t arena_size, int num_resource_variables, InterpreterConfig config, + size_t alt_decompression_memory_size) + : memory_arena_(new uint8_t[arena_size]), + alt_decompression_memory_(alt_decompression_memory_size > 0 + ? new uint8_t[alt_decompression_memory_size] + : nullptr), + alt_decompression_region_{alt_decompression_memory_.get(), + alt_decompression_memory_size} { interpreter_ = nullptr; // `model_data` is used as a raw pointer beyond the scope of this @@ -266,7 +273,6 @@ InterpreterWrapper::InterpreterWrapper( "--//:with_compression=true to enable compression support."); } - memory_arena_ = std::unique_ptr(new uint8_t[arena_size]); for (const std::string& registerer : registerers_by_name) { if (!AddCustomOpRegistererByName(registerer.c_str(), &python_ops_resolver_)) { @@ -296,6 +302,14 @@ InterpreterWrapper::InterpreterWrapper( interpreter_ = new MicroInterpreter(model, python_ops_resolver_, allocator_, resource_variables_); + if (alt_decompression_memory_size > 0) { + TfLiteStatus status = + interpreter_->SetDecompressionMemory(&alt_decompression_region_, 1); + if (status != kTfLiteOk) { + ThrowRuntimeError("TFLM failed to set decompression memory"); + } + } + TfLiteStatus status = interpreter_->AllocateTensors(); if (status != kTfLiteOk) { ThrowRuntimeError("TFLM failed to allocate tensors"); diff --git a/python/tflite_micro/interpreter_wrapper.h b/python/tflite_micro/interpreter_wrapper.h index 9bb31b067fe..d3a156b337a 100644 --- a/python/tflite_micro/interpreter_wrapper.h +++ b/python/tflite_micro/interpreter_wrapper.h @@ -19,6 +19,7 @@ limitations under the License. #include "python/tflite_micro/python_ops_resolver.h" #include "tensorflow/lite/micro/micro_allocator.h" +#include "tensorflow/lite/micro/micro_context.h" #include "tensorflow/lite/micro/micro_interpreter.h" #include "tensorflow/lite/micro/recording_micro_allocator.h" @@ -40,7 +41,8 @@ class InterpreterWrapper { InterpreterWrapper( PyObject* model_data, const std::vector& registerers_by_name, size_t arena_size, int num_resource_variables, - InterpreterConfig config = InterpreterConfig::kAllocationRecording); + InterpreterConfig config = InterpreterConfig::kAllocationRecording, + size_t alt_decompression_memory_size = 0); ~InterpreterWrapper(); void PrintAllocations(); @@ -57,6 +59,8 @@ class InterpreterWrapper { tflite::RecordingMicroAllocator* recording_allocator_ = nullptr; const PyObject* model_; std::unique_ptr memory_arena_; + std::unique_ptr alt_decompression_memory_; + tflite::MicroContext::AlternateMemoryRegion alt_decompression_region_; tflite::PythonOpsResolver python_ops_resolver_; tflite::MicroInterpreter* interpreter_; }; diff --git a/python/tflite_micro/python_ops_resolver.cc b/python/tflite_micro/python_ops_resolver.cc index 5f7d40fb74e..19f324bdf2f 100644 --- a/python/tflite_micro/python_ops_resolver.cc +++ b/python/tflite_micro/python_ops_resolver.cc @@ -40,7 +40,9 @@ PythonOpsResolver::PythonOpsResolver() { AddConv2D(); AddCos(); AddCumSum(); +#ifdef USE_TFLM_COMPRESSION AddDecode(); +#endif AddDelay(); AddDepthToSpace(); AddDepthwiseConv2D(); diff --git a/python/tflite_micro/runtime.py b/python/tflite_micro/runtime.py index d895f8c4993..7052972b4a6 100644 --- a/python/tflite_micro/runtime.py +++ b/python/tflite_micro/runtime.py @@ -100,6 +100,7 @@ def __init__( custom_op_registerers, arena_size, intrepreter_config=InterpreterConfig.kAllocationRecording, + alt_decompression_memory_size=0, ): if model_data is None: raise ValueError("Model must not be None") @@ -122,6 +123,7 @@ def __init__( arena_size, num_resource_variables, _ENUM_TRANSLATOR[intrepreter_config], + alt_decompression_memory_size, ) @classmethod @@ -131,6 +133,7 @@ def from_file( custom_op_registerers=[], arena_size=None, intrepreter_config=InterpreterConfig.kAllocationRecording, + alt_decompression_memory_size=0, ): """Instantiates a TFLM interpreter from a model .tflite filepath. @@ -140,6 +143,9 @@ def from_file( custom OP registerer arena_size: Tensor arena size in bytes. If unused, tensor arena size will default to 10 times the model size. + alt_decompression_memory_size: Size in bytes of alternate decompression + memory. If non-zero, DECODE operators will use this memory instead of + the main arena for decompressed tensor outputs. Returns: An Interpreter instance @@ -155,6 +161,7 @@ def from_file( custom_op_registerers, arena_size, intrepreter_config, + alt_decompression_memory_size, ) @classmethod @@ -164,6 +171,7 @@ def from_bytes( custom_op_registerers=[], arena_size=None, intrepreter_config=InterpreterConfig.kAllocationRecording, + alt_decompression_memory_size=0, ): """Instantiates a TFLM interpreter from a model in byte array. @@ -173,6 +181,9 @@ def from_bytes( custom OP registerer arena_size: Tensor arena size in bytes. If unused, tensor arena size will default to 10 times the model size. + alt_decompression_memory_size: Size in bytes of alternate decompression + memory. If non-zero, DECODE operators will use this memory instead of + the main arena for decompressed tensor outputs. Returns: An Interpreter instance @@ -183,6 +194,7 @@ def from_bytes( custom_op_registerers, arena_size, intrepreter_config, + alt_decompression_memory_size, ) def print_allocations(self): From 1d2f9c481eafc844f0e775c16e35974b4805cbb1 Mon Sep 17 00:00:00 2001 From: Ryan Kuester Date: Sun, 24 May 2026 23:28:09 -0500 Subject: [PATCH 09/19] feat(compression): add DECODE operator insertion Insert DECODE operators before consumers of compressed tensors. Each consumer gets its own DECODE operator to support alternate decompression memory, which resets allocations between DECODE invocations. After insertion, compressed tensors are rewritten to hold encoded data as UINT8 with shape matching byte count. BUG=part of #3256 --- tensorflow/lite/micro/compression/BUILD | 29 ++ .../lite/micro/compression/decode_insert.py | 268 +++++++++++ .../micro/compression/decode_insert_test.py | 417 ++++++++++++++++++ 3 files changed, 714 insertions(+) create mode 100644 tensorflow/lite/micro/compression/decode_insert.py create mode 100644 tensorflow/lite/micro/compression/decode_insert_test.py diff --git a/tensorflow/lite/micro/compression/BUILD b/tensorflow/lite/micro/compression/BUILD index 375a42d7a49..7d3291060fd 100644 --- a/tensorflow/lite/micro/compression/BUILD +++ b/tensorflow/lite/micro/compression/BUILD @@ -326,6 +326,35 @@ tflm_py_library( ], ) +tflm_py_library( + name = "decode_insert", + srcs = ["decode_insert.py"], + deps = [ + ":compressor", + ":model_editor", + "//tensorflow/lite/python:schema_py", + ], +) + +tflm_py_test( + name = "decode_insert_test", + size = "small", + srcs = ["decode_insert_test.py"], + tags = [ + "noasan", + "nomsan", + "noubsan", + ], + deps = [ + ":compressor", + ":decode", + ":decode_insert", + ":model_editor", + "//tensorflow/lite/python:schema_py", + requirement("numpy"), + ], +) + tflm_py_binary( name = "view", srcs = [ diff --git a/tensorflow/lite/micro/compression/decode_insert.py b/tensorflow/lite/micro/compression/decode_insert.py new file mode 100644 index 00000000000..43dffce46f0 --- /dev/null +++ b/tensorflow/lite/micro/compression/decode_insert.py @@ -0,0 +1,268 @@ +# Copyright 2026 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""DECODE operator insertion into TFLite model graphs. + +This module inserts DECODE operators into a compressed model. DECODE operators +transform encoded tensors (with their paired ancillary data tensors) into +tensors ready for use by downstream operators. + +The DECODE operator is registered as a custom operator named "TFLM_DECODE". +Each DECODE output requires two inputs: the encoded tensor and the ancillary +data tensor (containing the DCM header and decode-type-specific data). +""" + +import warnings +from collections import defaultdict +from dataclasses import dataclass +from typing import Optional + +from tflite_micro.tensorflow.lite.micro.compression import compressor +from tflite_micro.tensorflow.lite.micro.compression import model_editor +from tflite_micro.tensorflow.lite.python import schema_py_generated as tflite + +# Custom operator name for DECODE +DECODE_CUSTOM_OP_NAME = "TFLM_DECODE" + + +@dataclass +class _CompressedTensorInfo: + """Information about a compressed tensor for DECODE insertion.""" + subgraph_idx: int + tensor_idx: int + tensor: model_editor.Tensor + encoded_data: bytes + ancillary_data: bytes + consumers: list[model_editor.Operator] + + +def _find_tensor_consumers( + subgraph: model_editor.Subgraph, + tensor: model_editor.Tensor, +) -> list[model_editor.Operator]: + """Find all operators in subgraph that use tensor as an input.""" + consumers = [] + for op in subgraph.operators: + if tensor in op.inputs: + consumers.append(op) + return consumers + + +def _create_ancillary_tensor( + ancillary_data: bytes, + original_tensor: model_editor.Tensor, +) -> model_editor.Tensor: + """Create an ancillary data tensor for a compressed tensor. + + Args: + ancillary_data: The complete ancillary data (DCM + type-specific data). + original_tensor: The original tensor being decoded, for naming. + + Returns: + A new Tensor containing the ancillary data. + """ + name = None + if original_tensor.name: + name = f"{original_tensor.name}_ancillary" + + return model_editor.Tensor( + shape=(len(ancillary_data), ), + dtype=tflite.TensorType.UINT8, + data=ancillary_data, + name=name, + ) + + +def _create_output_tensor( + original_tensor: model_editor.Tensor, ) -> model_editor.Tensor: + """Create the output tensor for a DECODE operator. + + The output tensor has the same shape, dtype, and quantization as the + original tensor would have when decoded. It has no data---the DECODE + operator produces its values at runtime. + + Args: + original_tensor: The original tensor being decoded. + + Returns: + A new Tensor for the DECODE output. + """ + name = None + if original_tensor.name: + name = f"{original_tensor.name}_decoded" + + return model_editor.Tensor( + shape=original_tensor.shape, + dtype=original_tensor.dtype, + quantization=original_tensor.quantization, + name=name, + ) + + +def _rewire_consumers( + consumers: list[model_editor.Operator], + old_tensor: model_editor.Tensor, + new_tensor: model_editor.Tensor, +) -> None: + """Replace old_tensor with new_tensor in all consumer inputs.""" + for consumer in consumers: + consumer.inputs = [ + new_tensor if t is old_tensor else t for t in consumer.inputs + ] + + +def _rewrite_encoded_tensor( + tensor: model_editor.Tensor, + encoded_data: bytes, +) -> None: + """Rewrite a compressed tensor to hold encoded data. + + The original tensor contained uncompressed values with quantization. After + compression, it holds packed indices (or other encoded form) as raw bytes. + This function updates the tensor in place to reflect its new role. + + Args: + tensor: The tensor to rewrite. + encoded_data: The compressed/encoded data bytes. + """ + tensor.shape = (len(encoded_data), ) + tensor.dtype = tflite.TensorType.UINT8 + tensor.quantization = None + tensor.buffer.data = encoded_data + + +def insert_decode_operators( + model: model_editor.Model, + compression_results: dict[tuple[int, int], compressor.CompressionResult], +) -> None: + """Insert DECODE operators for all compressed tensors. + + This function modifies the model in-place, inserting DECODE operators + before any operator that uses a compressed tensor as input. + + A separate DECODE is inserted before each consumer, rather than sharing one + DECODE output among all consumers. This is required because the interpreter's + alternate decompression memory resets its allocation offset for each DECODE's + Prepare, causing all DECODE outputs to be allocated at the same address. If + two consumers share one DECODE and another DECODE runs between them, the + intervening DECODE overwrites the shared output, corrupting data for the + second consumer. + + For each consumer of a compressed tensor: + 1. Create an ancillary data tensor containing DCM + type-specific data + 2. Create an output tensor with the same shape/dtype as the decoded tensor + 3. Insert a DECODE operator immediately before the consumer + 4. Rewire the consumer to use the DECODE output + + Args: + model: The model to modify in-place. + compression_results: Map from (subgraph_idx, tensor_idx) to the + CompressionResult containing ancillary_data. + """ + # Group compressed tensors by subgraph + by_subgraph: dict[int, list[_CompressedTensorInfo]] = defaultdict(list) + + for (sg_idx, tensor_idx), result in compression_results.items(): + subgraph = model.subgraphs[sg_idx] + tensor = subgraph.tensors[tensor_idx] + consumers = _find_tensor_consumers(subgraph, tensor) + + if not consumers: + # Check if tensor is a subgraph output + is_output = tensor in subgraph.outputs + if is_output: + # TODO: Handle compressed tensors that are subgraph outputs. + # This occurs in multi-subgraph models using IF/WHILE where a + # compressed tensor flows out of a subgraph. + raise NotImplementedError( + f"Compressed tensor {tensor.name!r} (subgraph {sg_idx}, " + f"tensor {tensor_idx}) is a subgraph output with no consumers. " + "Compressed subgraph outputs are not yet supported.") + else: + warnings.warn( + f"Compressed tensor {tensor.name!r} (subgraph {sg_idx}, " + f"tensor {tensor_idx}) has no consumers and is not a subgraph " + "output. No DECODE operator will be inserted.", + stacklevel=2) + continue + + info = _CompressedTensorInfo( + subgraph_idx=sg_idx, + tensor_idx=tensor_idx, + tensor=tensor, + encoded_data=result.encoded_data, + ancillary_data=result.ancillary_data, + consumers=consumers, + ) + by_subgraph[sg_idx].append(info) + + # Process each subgraph + for sg_idx, tensor_infos in by_subgraph.items(): + subgraph = model.subgraphs[sg_idx] + + # Collect all (consumer, tensor_info) pairs and sort by consumer position + # in reverse order so insertions don't invalidate positions + consumer_pairs = [] + for info in tensor_infos: + for consumer in info.consumers: + consumer_pairs.append((consumer, info)) + + consumer_pairs.sort( + key=lambda pair: subgraph.operators.index(pair[0]), + reverse=True, + ) + + # Cache ancillary tensors by original tensor to avoid duplicates. Each + # DECODE needs its own output tensor, but ancillary data is identical for + # all DECODEs of the same compressed tensor. + ancillary_cache: dict[model_editor.Tensor, model_editor.Tensor] = {} + + # Track tensors to rewrite after all output tensors are created, since + # _create_output_tensor reads the original tensor's shape/dtype/quantization. + tensors_to_rewrite: dict[model_editor.Tensor, bytes] = {} + + for consumer, info in consumer_pairs: + # Reuse or create ancillary data tensor + if info.tensor not in ancillary_cache: + ancillary_tensor = _create_ancillary_tensor( + info.ancillary_data, + info.tensor, + ) + subgraph.tensors.append(ancillary_tensor) + ancillary_cache[info.tensor] = ancillary_tensor + tensors_to_rewrite[info.tensor] = info.encoded_data + else: + ancillary_tensor = ancillary_cache[info.tensor] + + # Create output tensor (one per DECODE) + output_tensor = _create_output_tensor(info.tensor) + subgraph.tensors.append(output_tensor) + + # Create DECODE operator + decode_op = model_editor.Operator( + opcode=tflite.BuiltinOperator.CUSTOM, + custom_code=DECODE_CUSTOM_OP_NAME, + inputs=[info.tensor, ancillary_tensor], + outputs=[output_tensor], + ) + + # Insert DECODE immediately before this consumer + insert_pos = subgraph.operators.index(consumer) + subgraph.operators.insert(insert_pos, decode_op) + + # Rewire only this consumer to use the decoded output + _rewire_consumers([consumer], info.tensor, output_tensor) + + # Rewrite encoded tensors after all output tensors are created + for tensor, encoded_data in tensors_to_rewrite.items(): + _rewrite_encoded_tensor(tensor, encoded_data) diff --git a/tensorflow/lite/micro/compression/decode_insert_test.py b/tensorflow/lite/micro/compression/decode_insert_test.py new file mode 100644 index 00000000000..11be81963d9 --- /dev/null +++ b/tensorflow/lite/micro/compression/decode_insert_test.py @@ -0,0 +1,417 @@ +# Copyright 2026 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Unit tests for DECODE operator insertion.""" + +import warnings + +import numpy as np +import unittest + +from tflite_micro.tensorflow.lite.micro.compression import compressor +from tflite_micro.tensorflow.lite.micro.compression import decode +from tflite_micro.tensorflow.lite.micro.compression import decode_insert +from tflite_micro.tensorflow.lite.micro.compression import model_editor +from tflite_micro.tensorflow.lite.python import schema_py_generated as tflite + + +def _build_simple_fc_model(): + """Build a simple model with one FC operator and compressible weights.""" + # yapf: disable + weights = model_editor.Tensor( + shape=(4, 4), + dtype=tflite.TensorType.INT8, + data=np.array([[1, 2, 1, 2], + [3, 4, 3, 4], + [1, 2, 1, 2], + [3, 4, 3, 4]], dtype=np.int8), + name="weights", + quantization=model_editor.Quantization(scales=0.5, zero_points=0), + ) + # yapf: enable + input_t = model_editor.Tensor( + shape=(1, 4), + dtype=tflite.TensorType.INT8, + name="input", + ) + output_t = model_editor.Tensor( + shape=(1, 4), + dtype=tflite.TensorType.INT8, + name="output", + ) + + model = model_editor.Model(subgraphs=[ + model_editor.Subgraph( + tensors=[weights], + operators=[ + model_editor.Operator( + opcode=tflite.BuiltinOperator.FULLY_CONNECTED, + inputs=[input_t, weights], + outputs=[output_t], + ) + ], + ) + ]) + return model + + +def _build_shared_weights_model(): + """Build model where one tensor is used by multiple operators.""" + weights = model_editor.Tensor( + shape=(4, 4), + dtype=tflite.TensorType.INT8, + data=np.ones((4, 4), dtype=np.int8), + name="shared_weights", + quantization=model_editor.Quantization(scales=0.5, zero_points=0), + ) + input1 = model_editor.Tensor( + shape=(1, 4), + dtype=tflite.TensorType.INT8, + name="input1", + ) + input2 = model_editor.Tensor( + shape=(1, 4), + dtype=tflite.TensorType.INT8, + name="input2", + ) + output1 = model_editor.Tensor( + shape=(1, 4), + dtype=tflite.TensorType.INT8, + name="output1", + ) + output2 = model_editor.Tensor( + shape=(1, 4), + dtype=tflite.TensorType.INT8, + name="output2", + ) + + model = model_editor.Model(subgraphs=[ + model_editor.Subgraph( + tensors=[weights], + operators=[ + model_editor.Operator( + opcode=tflite.BuiltinOperator.FULLY_CONNECTED, + inputs=[input1, weights], + outputs=[output1], + ), + model_editor.Operator( + opcode=tflite.BuiltinOperator.FULLY_CONNECTED, + inputs=[input2, weights], + outputs=[output2], + ), + ], + ) + ]) + return model + + +def _make_dummy_ancillary_data() -> bytes: + """Create dummy ancillary data for testing.""" + dcm = decode.DecodeCommonMetadata( + decode_type=decode.DecodeType.LUT, + user_data=b'\x01\x04\x10' + b'\x00' * 9, # lut_version, bitwidth, stride + ) + value_tables = bytes([1, 2, 3, 4] + [0] * 12) # 16-byte padded table + return dcm.to_bytes() + value_tables + + +class TestDecodeInsertion(unittest.TestCase): + """Tests for insert_decode_operators function.""" + + def test_insert_single_decode_operator(self): + """DECODE operator inserted before FC that uses compressed weights.""" + model = _build_simple_fc_model() + weights_tensor = model.subgraphs[0].tensor_by_name("weights") + + # Create compression result + compression_results = { + (0, 0): + compressor.CompressionResult( + encoded_data=b'\x00\x00', + ancillary_data=_make_dummy_ancillary_data(), + ) + } + + # Insert DECODE operators + decode_insert.insert_decode_operators(model, compression_results) + + sg = model.subgraphs[0] + + # Should have 2 operators: DECODE then FC + self.assertEqual(len(sg.operators), 2) + self.assertEqual(sg.operators[0].opcode, tflite.BuiltinOperator.CUSTOM) + self.assertEqual(sg.operators[0].custom_code, + decode_insert.DECODE_CUSTOM_OP_NAME) + self.assertEqual(sg.operators[1].opcode, + tflite.BuiltinOperator.FULLY_CONNECTED) + + def test_decode_inputs_structure(self): + """DECODE operator has correct inputs: encoded tensor + ancillary.""" + model = _build_simple_fc_model() + weights_tensor = model.subgraphs[0].tensor_by_name("weights") + + compression_results = { + (0, 0): + compressor.CompressionResult( + encoded_data=b'\x00\x00', + ancillary_data=_make_dummy_ancillary_data(), + ) + } + + decode_insert.insert_decode_operators(model, compression_results) + + decode_op = model.subgraphs[0].operators[0] + + # DECODE has 2 inputs + self.assertEqual(len(decode_op.inputs), 2) + # First input is the encoded tensor (original weights) + self.assertIs(decode_op.inputs[0], weights_tensor) + # Second input is ancillary tensor + self.assertEqual(decode_op.inputs[1].dtype, tflite.TensorType.UINT8) + + def test_decode_output_structure(self): + """DECODE operator output has correct shape and dtype.""" + model = _build_simple_fc_model() + weights_tensor = model.subgraphs[0].tensor_by_name("weights") + + # Save original properties before rewrite + original_shape = weights_tensor.shape + original_dtype = weights_tensor.dtype + + compression_results = { + (0, 0): + compressor.CompressionResult( + encoded_data=b'\x00\x00', + ancillary_data=_make_dummy_ancillary_data(), + ) + } + + decode_insert.insert_decode_operators(model, compression_results) + + decode_op = model.subgraphs[0].operators[0] + output = decode_op.outputs[0] + + # Output matches original (pre-rewrite) tensor shape and dtype + self.assertEqual(output.shape, original_shape) + self.assertEqual(output.dtype, original_dtype) + + def test_consumer_rewired_to_decode_output(self): + """FC operator input rewired to use DECODE output.""" + model = _build_simple_fc_model() + weights_tensor = model.subgraphs[0].tensor_by_name("weights") + + compression_results = { + (0, 0): + compressor.CompressionResult( + encoded_data=b'\x00\x00', + ancillary_data=_make_dummy_ancillary_data(), + ) + } + + decode_insert.insert_decode_operators(model, compression_results) + + decode_op = model.subgraphs[0].operators[0] + fc_op = model.subgraphs[0].operators[1] + + # FC's second input (weights) should now be DECODE's output + self.assertIs(fc_op.inputs[1], decode_op.outputs[0]) + # Original weights tensor should NOT be in FC inputs + self.assertNotIn(weights_tensor, fc_op.inputs) + + def test_shared_tensor_decode_per_consumer(self): + """Tensor used by multiple ops gets separate DECODE for each consumer.""" + model = _build_shared_weights_model() + weights_tensor = model.subgraphs[0].tensor_by_name("shared_weights") + + compression_results = { + (0, 0): + compressor.CompressionResult( + encoded_data=b'\x00\x00', + ancillary_data=_make_dummy_ancillary_data(), + ) + } + + decode_insert.insert_decode_operators(model, compression_results) + + sg = model.subgraphs[0] + + # Should have 4 operators: 2 DECODEs + 2 FCs (DECODE before each FC) + self.assertEqual(len(sg.operators), 4) + self.assertEqual(sg.operators[0].opcode, tflite.BuiltinOperator.CUSTOM) + self.assertEqual(sg.operators[1].opcode, + tflite.BuiltinOperator.FULLY_CONNECTED) + self.assertEqual(sg.operators[2].opcode, tflite.BuiltinOperator.CUSTOM) + self.assertEqual(sg.operators[3].opcode, + tflite.BuiltinOperator.FULLY_CONNECTED) + + decode_op1 = sg.operators[0] + fc_op1 = sg.operators[1] + decode_op2 = sg.operators[2] + fc_op2 = sg.operators[3] + + # Each FC should use its own DECODE's output + self.assertIs(fc_op1.inputs[1], decode_op1.outputs[0]) + self.assertIs(fc_op2.inputs[1], decode_op2.outputs[0]) + # The two DECODEs should have different outputs + self.assertIsNot(decode_op1.outputs[0], decode_op2.outputs[0]) + # The two DECODEs should share the same ancillary tensor + self.assertIs(decode_op1.inputs[1], decode_op2.inputs[1]) + + def test_ancillary_tensor_contains_dcm(self): + """Ancillary tensor data contains valid DCM header.""" + model = _build_simple_fc_model() + + ancillary_data = _make_dummy_ancillary_data() + compression_results = { + (0, 0): + compressor.CompressionResult( + encoded_data=b'\x00\x00', + ancillary_data=ancillary_data, + ) + } + + decode_insert.insert_decode_operators(model, compression_results) + + decode_op = model.subgraphs[0].operators[0] + ancillary_tensor = decode_op.inputs[1] + + # Ancillary tensor data should match what we provided + self.assertEqual(bytes(ancillary_tensor.array), ancillary_data) + + # Verify DCM header + dcm_bytes = ancillary_tensor.array[:16] + self.assertEqual(dcm_bytes[0], 0) # decode_type = LUT + self.assertEqual(dcm_bytes[1], 1) # DCM version + + def test_no_consumers_no_decode(self): + """Tensor with no consumers gets no DECODE operator and emits warning.""" + # Create model where compressed tensor is not used as input + unused_tensor = model_editor.Tensor( + shape=(4, 4), + dtype=tflite.TensorType.INT8, + data=np.ones((4, 4), dtype=np.int8), + name="unused", + quantization=model_editor.Quantization(scales=0.5, zero_points=0), + ) + input_t = model_editor.Tensor( + shape=(1, 4), + dtype=tflite.TensorType.INT8, + name="input", + ) + output_t = model_editor.Tensor( + shape=(1, 4), + dtype=tflite.TensorType.INT8, + name="output", + ) + other_weights = model_editor.Tensor( + shape=(4, 4), + dtype=tflite.TensorType.INT8, + data=np.ones((4, 4), dtype=np.int8), + name="other_weights", + quantization=model_editor.Quantization(scales=0.5, zero_points=0), + ) + + model = model_editor.Model(subgraphs=[ + model_editor.Subgraph( + tensors=[unused_tensor, other_weights], + operators=[ + model_editor.Operator( + opcode=tflite.BuiltinOperator.FULLY_CONNECTED, + inputs=[input_t, other_weights], + outputs=[output_t], + ) + ], + ) + ]) + + # Compress the unused tensor + compression_results = { + (0, 0): + compressor.CompressionResult( + encoded_data=b'\x00\x00', + ancillary_data=_make_dummy_ancillary_data(), + ) + } + + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + decode_insert.insert_decode_operators(model, compression_results) + + # Should emit a warning about no consumers + self.assertEqual(len(w), 1) + self.assertIn("no consumers", str(w[0].message)) + self.assertIn("unused", str(w[0].message)) + + # Should still have just 1 operator (no DECODE inserted) + self.assertEqual(len(model.subgraphs[0].operators), 1) + + def test_tensor_naming(self): + """Output and ancillary tensors get appropriate names.""" + model = _build_simple_fc_model() + + compression_results = { + (0, 0): + compressor.CompressionResult( + encoded_data=b'\x00\x00', + ancillary_data=_make_dummy_ancillary_data(), + ) + } + + decode_insert.insert_decode_operators(model, compression_results) + + decode_op = model.subgraphs[0].operators[0] + ancillary = decode_op.inputs[1] + output = decode_op.outputs[0] + + self.assertEqual(ancillary.name, "weights_ancillary") + self.assertEqual(output.name, "weights_decoded") + + def test_encoded_tensor_rewritten(self): + """Compressed tensor is rewritten with encoded data, UINT8 type, no quant.""" + model = _build_simple_fc_model() + weights_tensor = model.subgraphs[0].tensor_by_name("weights") + + encoded_data = b'\xAB\xCD\xEF' + compression_results = { + (0, 0): + compressor.CompressionResult( + encoded_data=encoded_data, + ancillary_data=_make_dummy_ancillary_data(), + ) + } + + decode_insert.insert_decode_operators(model, compression_results) + + # Original tensor should be rewritten + self.assertEqual(weights_tensor.shape, (len(encoded_data), )) + self.assertEqual(weights_tensor.dtype, tflite.TensorType.UINT8) + self.assertIsNone(weights_tensor.quantization) + self.assertEqual(weights_tensor.buffer.data, encoded_data) + + +class TestHelperFunctions(unittest.TestCase): + """Tests for internal helper functions.""" + + def test_find_tensor_consumers(self): + """_find_tensor_consumers finds all ops using a tensor.""" + model = _build_shared_weights_model() + sg = model.subgraphs[0] + weights = sg.tensor_by_name("shared_weights") + + consumers = decode_insert._find_tensor_consumers(sg, weights) + + self.assertEqual(len(consumers), 2) + + +if __name__ == "__main__": + unittest.main() From 6c05fb9c2f5c5f7be19063804a984deb66661f53 Mon Sep 17 00:00:00 2001 From: Ryan Kuester Date: Sun, 24 May 2026 23:30:48 -0500 Subject: [PATCH 10/19] refactor(compression): use plugin architecture in compress.py Replace monolithic compression logic with a dispatch table that routes compression requests to plugin modules based on the spec's compression method type. After compressing tensors, insert DECODE operators into the model graph. Warn when compression expands data, helping users identify tensors that don't benefit from compression. BUG=part of #3256 --- tensorflow/lite/micro/compression/BUILD | 13 +- tensorflow/lite/micro/compression/compress.py | 310 ++------ .../lite/micro/compression/compress_test.py | 700 ++++++++---------- 3 files changed, 389 insertions(+), 634 deletions(-) diff --git a/tensorflow/lite/micro/compression/BUILD b/tensorflow/lite/micro/compression/BUILD index 7d3291060fd..3478d584b56 100644 --- a/tensorflow/lite/micro/compression/BUILD +++ b/tensorflow/lite/micro/compression/BUILD @@ -123,14 +123,15 @@ py_library( "compress.py", ], deps = [ - ":metadata_py", + ":compressor", + ":decode_insert", + ":huffman", + ":lut", ":model_editor", + ":pruning", ":spec", "//tensorflow/lite/micro/tools:tflite_flatbuffer_align", requirement("absl_py"), - requirement("flatbuffers"), - requirement("bitarray"), - requirement("numpy"), ], ) @@ -159,11 +160,11 @@ py_test( target_compatible_with = INCOMPATIBLE_WITH_WINDOWS, deps = [ ":compress", - ":metadata_py", + ":compressor", + ":decode_insert", ":model_editor", ":spec", "//tensorflow/lite/python:schema_py", - requirement("bitarray"), requirement("numpy"), ], ) diff --git a/tensorflow/lite/micro/compression/compress.py b/tensorflow/lite/micro/compression/compress.py index b6d5aef4435..270951fecf8 100644 --- a/tensorflow/lite/micro/compression/compress.py +++ b/tensorflow/lite/micro/compression/compress.py @@ -16,22 +16,22 @@ See USAGE. """ -import bitarray -import bitarray.util -from dataclasses import dataclass, field import os import sys import tempfile -from typing import ByteString, Iterable, Optional +import warnings +from typing import ByteString, Iterable, Type import absl.app import absl.flags -import flatbuffers -import numpy as np +from tflite_micro.tensorflow.lite.micro.compression import compressor +from tflite_micro.tensorflow.lite.micro.compression import decode_insert +from tflite_micro.tensorflow.lite.micro.compression import huffman +from tflite_micro.tensorflow.lite.micro.compression import lut from tflite_micro.tensorflow.lite.micro.compression import model_editor +from tflite_micro.tensorflow.lite.micro.compression import pruning from tflite_micro.tensorflow.lite.micro.compression import spec -from tflite_micro.tensorflow.lite.micro.compression import metadata_py_generated as schema from tflite_micro.tensorflow.lite.micro.tools import tflite_flatbuffer_align_wrapper USAGE = f"""\ @@ -49,221 +49,48 @@ {spec.EXAMPLE_YAML_SPEC} --- -The only compression method currently implemented is "lut", i.e., -Look-Up-Table. This method requires the tensor in the input model to have a -small number of unique values, fewer than or equal to 2**index_bitwidth. LUT -compression collects these values into a lookup table, and rewrites the tensor -as bitwidth-wide integer indices into that lookup table. Presumably, the input -model has been trained or preprocessed in a way that the tensor values -are binned into a meaningful, limited set. -""" - -# A compressed model augments the usual .tflite flatbuffer with a flatbuffer of -# its own containing compression metadata, stored at the buffer index stored at -# the following key in the .tflite flatbuffer's metadata map. -TFLITE_METADATA_KEY = "COMPRESSION_METADATA" - - -class CompressionError(Exception): - """Raised when compression fails for the reason documented in the message.""" - - def __init__(self, message, wrapped_exception=None): - super().__init__(f"{message}: {str(wrapped_exception)}") - self.original_exception = wrapped_exception - - -class _MetadataBuilder: - """Builder for the compression metadata flatbuffer.""" - - def __init__(self): - self._metadata = schema.MetadataT() - self._metadata.subgraphs = [] - - def compile(self) -> bytearray: - """Packs the metadata into a binary array and returns it. - """ - builder = flatbuffers.Builder(1 * 2**10) - root = self._metadata.Pack(builder) - builder.Finish(root) - return builder.Output() - - def subgraph(self, index: int): - """Return subgraph at index, adding subgraphs if necessary. - """ - while len(self._metadata.subgraphs) <= index: - self._add_subgraph() - return self._metadata.subgraphs[index] - - def add_lut_tensor(self, subgraph_id: int): - """Add LUT tensor to the given subgraph and return it. - """ - tensor = schema.LutTensorT() - self.subgraph(subgraph_id).lutTensors.append(tensor) - return tensor - - def _add_subgraph(self): - subgraph = schema.SubgraphT() - subgraph.lutTensors = [] - self._metadata.subgraphs.append(subgraph) - return subgraph - - -@dataclass -class _LutCompressedArray: - compression_axis: Optional[int] = None - lookup_tables: list[np.ndarray] = field(default_factory=list) - indices: np.ndarray = field(default_factory=lambda: np.array([])) - - @property - def index_bitwidth(self) -> int: - """Returns the number of bits required to encode the indices.""" - if self.indices is None: - raise ValueError - - max_index = int(np.max(self.indices)) - return max_index.bit_length() or 1 - - -def _lut_compress_array(tensor: np.ndarray, - axis: Optional[int]) -> _LutCompressedArray: - """Compresses the given tensor using lookup tables. - - Args: - tensor (np.ndarray): The tensor to be compressed. - - axis (Optional[int]): The axis along which to compress the tensor. If an - axis is given, a lookup table is created for each slice along the - axis. If axis is None, a single lookup table is used for the entire - tensor. - - Compressing a tensor with a lookup table per slice along a - particular axis is analogous to quantizing a tensor with different - quantization parameters per slice along a particular axis (dimension). - - Returns: - _LutCompressedArray: An object containing the compressed tensor data, - including the lookup tables and indices. - """ - compressed = _LutCompressedArray() - compressed.compression_axis = axis - - if axis is None: - # Compute unique values and indices for the entire tensor - values, indices = np.unique(tensor, return_inverse=True) - compressed.lookup_tables.append(values) - compressed.indices = indices.reshape(tensor.shape) - else: - # Iterate over slices along the compression axis - slice_indices = [] - for slice in np.moveaxis(tensor, axis, 0): - values, indices = np.unique(slice, return_inverse=True) - compressed.lookup_tables.append(values) - indices = indices.reshape(slice.shape) - slice_indices.append(indices) - - # Reconstruct a tensor of indices from the slices - stacked = np.stack(slice_indices, axis=0) - compressed.indices = np.moveaxis(stacked, 0, axis) - - return compressed - - -def _check_lut_compression(compression) -> spec.LookUpTableCompression: - if len(compression) != 1: - raise CompressionError("Each tensor must have exactly one compression") - if not isinstance(compression[0], spec.LookUpTableCompression): - raise CompressionError('Only "lut" compression may be specified') - - return compression[0] - - -def _identify_compression_axis(tensor: model_editor.Tensor) -> Optional[int]: - """Determines the axis along which to compress. - - The axis along which to compress is inferred from the tensor's quantization - parameters. - - Returns: - The axis along which to compress, or None to indicate one value table for - the entire tensor. - - Raises: - CompressionError: If the axis cannot be determined. - """ - q = tensor.quantization - if q is not None: - # model_editor wraps quantization, access scales/axis from wrapper - scales = q.scales if isinstance(q.scales, list) else [q.scales] - quantization_channels = len(scales) +Supported compression methods: - if quantization_channels == 1: - # Use one value table for the entire tensor - return None + lut: Look-Up-Table compression. Requires the tensor to have a small number of + unique values, fewer than or equal to 2**index_bitwidth. LUT compression + collects these values into a lookup table, and rewrites the tensor as + bitwidth-wide integer indices into that lookup table. - if q.axis is not None and q.axis < len(tensor.shape): - if quantization_channels == tensor.shape[q.axis]: - return q.axis + huffman: Huffman compression using Xtensa-format decode tables. (Not yet + implemented.) - raise CompressionError( - f"Invalid or no quanitzation parameters from which to " - f"infer the axis along which tensor should be compressed.") - - -def _check_bitwidth(compressed: int, specified: int, spec: spec.Tensor): - """Applies business logic regarding specified bitwidth. - - It is an error if the bitwidth required to compress a tensor exceeds the - specified bitwith, and a warning if the tensor can be compressed in less than - the specified bitwidth. The latter is allowed, and is not an error, to permit - testing with larger bitwidths without re-binning a model. - """ - if compressed > specified: - raise CompressionError( - f"index_bitwidth too small: {compressed} bits needed to " - f"enumerate unique values in tensor specified in {spec}") - elif compressed < specified: - print( - f"warning: index_bitwidth too large: only {compressed} " - f"bits needed to enumerate unique values in tensor specified in {spec}", - file=sys.stderr) - - -def _pack_indices(indices: np.ndarray, bitwidth: int) -> bytes: - """Packs indices into a bytearray using bitwidth-sized fields. - """ - endianness = "big" - bits = bitarray.bitarray(endian=endianness) - for i in indices.ravel(): - bits.extend( - bitarray.util.int2ba(int(i), length=bitwidth, endian=endianness)) - return bits.tobytes() + pruning: Pruning (sparsity) compression for sparse tensors. (Not yet + implemented.) +Compressed models use DECODE operators to decompress tensors at runtime. +""" -def _pack_lookup_tables(tables: list[np.ndarray], table_len: int) -> bytearray: - """Packs the value tables of a LutCompressedArray. +# Plugin dispatch table: maps CompressionMethod subclasses to compressor instances +_COMPRESSORS: dict[Type[spec.CompressionMethod], compressor.Compressor] = { + spec.LookUpTableCompression: lut.LutCompressor(), + spec.HuffmanCompression: huffman.HuffmanCompressor(), + spec.PruningCompression: pruning.PruningCompressor(), +} - Pack the value tables of a LutCompressedArray into a bytes object in the - format writable to a value_table buffer in the .tflite flatbuffer. The - tables are concatenated. - """ - buffer = bytearray() - for t in tables: - padding_needed = table_len - len(t) - padded = np.pad(t, (0, padding_needed), mode='constant', constant_values=0) - buffer.extend(padded.tobytes()) - return buffer +def _get_compressor(method: spec.CompressionMethod) -> compressor.Compressor: + """Get the compressor plugin for a given compression method.""" + compressor_instance = _COMPRESSORS.get(type(method)) + if compressor_instance is None: + raise compressor.CompressionError( + f"No compressor registered for {type(method).__name__}") + return compressor_instance def _apply_flatbuffer_alignment(model_bytes: bytearray) -> bytearray: """Applies proper FlatBuffer alignment to a model. - + The Python flatbuffers library doesn't respect `force_align` schema attributes, so we use the C++ wrapper which properly handles alignment requirements. - + Args: model_bytes: The model flatbuffer to align - + Returns: The properly aligned model flatbuffer """ @@ -295,45 +122,58 @@ def _apply_flatbuffer_alignment(model_bytes: bytearray) -> bytearray: def compress(model_in: ByteString, specs: Iterable[spec.Tensor]) -> bytearray: """Compresses a model .tflite flatbuffer. + Compresses tensors according to the given specs and inserts DECODE operators + to decompress them at runtime. + Args: model_in: the original, uncompressed .tflite flatbuffer specs: an iterable of compression specs, see module spec.py Returns: - A compressed flatbuffer. + A compressed flatbuffer with DECODE operators inserted. """ model = model_editor.read(model_in) - metadata = _MetadataBuilder() + compression_results: dict[tuple[int, int], compressor.CompressionResult] = {} - for spec in specs: + for tensor_spec in specs: try: - tensor = model.subgraphs[spec.subgraph].tensors[spec.tensor] - lut_compression = _check_lut_compression(spec.compression) - spec_bitwidth = lut_compression.index_bitwidth - axis = _identify_compression_axis(tensor) - compressed = _lut_compress_array(tensor.array, axis) - _check_bitwidth(compressed.index_bitwidth, spec_bitwidth, spec) - - # overwrite tensor data with indices - tensor.buffer.data = _pack_indices(compressed.indices, spec_bitwidth) - - # write value buffer - value_buffer_data = _pack_lookup_tables(compressed.lookup_tables, - 2**spec_bitwidth) - value_buffer = model_editor.Buffer(data=value_buffer_data) - model.buffers.append(value_buffer) # Auto-sets value_buffer.index - - # add compression metadata for tensor - lut_tensor = metadata.add_lut_tensor(subgraph_id=spec.subgraph) - lut_tensor.tensor = spec.tensor - lut_tensor.valueBuffer = value_buffer.index - lut_tensor.indexBitwidth = spec_bitwidth - + tensor = model.subgraphs[tensor_spec.subgraph].tensors[ + tensor_spec.tensor] + + # Currently only one compression method per tensor + if len(tensor_spec.compression) != 1: + raise compressor.CompressionError( + "Each tensor must have exactly one compression method") + + method = tensor_spec.compression[0] + plugin = _get_compressor(method) + original_size = len(tensor.buffer.data) if tensor.buffer.data else 0 + result = plugin.compress(tensor, method) + + compressed_size = len(result.encoded_data) + len(result.ancillary_data) + if compressed_size > original_size: + warnings.warn( + f"Compression of tensor {tensor.name!r} (subgraph " + f"{tensor_spec.subgraph}, tensor {tensor_spec.tensor}) resulted " + f"in expansion: {original_size} bytes -> {compressed_size} bytes " + f"(encoded: {len(result.encoded_data)}, " + f"ancillary: {len(result.ancillary_data)})", + stacklevel=2) + + # Replace tensor data with encoded data + tensor.buffer.data = result.encoded_data + + # Store result for DECODE insertion + compression_results[(tensor_spec.subgraph, tensor_spec.tensor)] = result + + except compressor.CompressionError: + raise except Exception as e: - raise CompressionError(f"error compressing {spec}") from e + raise compressor.CompressionError( + f"error compressing {tensor_spec}") from e - # add compression metadata to model - model.metadata[TFLITE_METADATA_KEY] = metadata.compile() + # Insert DECODE operators into the graph + decode_insert.insert_decode_operators(model, compression_results) # Build the model and apply proper alignment unaligned_model = model.build() diff --git a/tensorflow/lite/micro/compression/compress_test.py b/tensorflow/lite/micro/compression/compress_test.py index 81bbdab3293..cb241c2c62f 100644 --- a/tensorflow/lite/micro/compression/compress_test.py +++ b/tensorflow/lite/micro/compression/compress_test.py @@ -11,164 +11,21 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +"""Integration tests for the compression system.""" + +import warnings -import bitarray -import bitarray.util import numpy as np import unittest from tflite_micro.tensorflow.lite.micro.compression import compress -from tflite_micro.tensorflow.lite.micro.compression import metadata_py_generated as schema +from tflite_micro.tensorflow.lite.micro.compression import compressor +from tflite_micro.tensorflow.lite.micro.compression import decode_insert from tflite_micro.tensorflow.lite.micro.compression import model_editor from tflite_micro.tensorflow.lite.micro.compression import spec from tflite_micro.tensorflow.lite.python import schema_py_generated as tflite -class TestPackIndices(unittest.TestCase): - - def test_basic_case(self): - indices = np.array([1, 2, 3]) - bitwidth = 4 - result = compress._pack_indices(indices, bitwidth) - expected_bytes = bytes([0b0001_0010, 0b0011_0000]) - self.assertEqual(result, expected_bytes) - - def test_single_element(self): - indices = np.array([10]) - bitwidth = 8 - result = compress._pack_indices(indices, bitwidth) - expected_bytes = bytes([0b0000_1010]) - self.assertEqual(result, expected_bytes) - - def test_different_bitwidth(self): - indices = np.array([1, 2, 3]) - bitwidth = 8 - result = compress._pack_indices(indices, bitwidth) - expected_bytes = bytes([0b0000_0001, 0b0000_0010, 0b0000_0011]) - self.assertEqual(result, expected_bytes) - - def test_large_numbers(self): - indices = np.array([255, 128, 64]) - bitwidth = 8 - result = compress._pack_indices(indices, bitwidth) - expected_bytes = bytes([0b1111_1111, 0b1000_0000, 0b0100_0000]) - self.assertEqual(result, expected_bytes) - - def test_multidimensional_array(self): - indices = np.array([[1, 2], [3, 4]]) - bitwidth = 4 - result = compress._pack_indices(indices, bitwidth) - expected_bytes = bytes([0b0001_0010, 0b0011_0100]) - self.assertEqual(result, expected_bytes) - - def test_zero_bitwidth(self): - indices = np.array([0, 1, 2]) - bitwidth = 0 - with self.assertRaises(ValueError): - compress._pack_indices(indices, bitwidth) - - def test_empty_array(self): - indices = np.array([]) - bitwidth = 4 - result = compress._pack_indices(indices, bitwidth) - expected_bytes = b"" - self.assertEqual(result, expected_bytes) - - def test_bitwidth_1(self): - indices = np.array([1, 0, 1, 1, 0, 1]) - bitwidth = 1 - result = compress._pack_indices(indices, bitwidth) - expected_bytes = bytes([0b101101_00]) - self.assertEqual(result, expected_bytes) - - def test_bitwidth_2(self): - indices = np.array([1, 2, 3, 0]) - bitwidth = 2 - result = compress._pack_indices(indices, bitwidth) - expected_bytes = bytes([0b01_10_11_00]) - self.assertEqual(result, expected_bytes) - - def test_bitwidth_3(self): - indices = np.array([1, 3, 5, 7]) - bitwidth = 3 - result = compress._pack_indices(indices, bitwidth) - expected_bytes = bytes([0b001_011_10, 0b1_111_0000]) - self.assertEqual(result, expected_bytes) - - def test_bitwidth_5(self): - indices = np.array([1, 2, 16, 31]) - bitwidth = 5 - result = compress._pack_indices(indices, bitwidth) - expected_bytes = bytes([0b00001_000, 0b10_10000_1, 0b1111_0000]) - self.assertEqual(result, expected_bytes) - - def test_bitwidth_7(self): - indices = np.array([1, 64, 127, 32]) - bitwidth = 7 - result = compress._pack_indices(indices, bitwidth) - expected_bytes = bytes( - [0b0000001_1, 0b000000_11, 0b11111_010, 0b0000_0000]) - self.assertEqual(result, expected_bytes) - - -class TestPackLookupTables(unittest.TestCase): - - def test_int16_positive(self): - tables = [np.array([0x1234, 0x5678], dtype=' tuple[int, bitarray.bitarray, np.ndarray]: - """Helper: extracts the compressed tensor parts for a given spec. - - Returns: - bitwidth - indices - values - """ - subgraph_obj = self.compressed.subgraphs[subgraph] - tensor_obj = subgraph_obj.tensors[tensor] - lut_tensors = self.metadata.subgraphs[subgraph_obj.index].lutTensors - lut_tensor = next(t for t in lut_tensors if t.tensor == tensor_obj.index) - bitwidth = lut_tensor.indexBitwidth - - indices = bitarray.bitarray(buffer=tensor_obj.buffer.data, endian="big") - n_indices = np.prod(tensor_obj.shape) - indices = indices[:n_indices * bitwidth] # trim possible padding - - value_buffer = self.compressed.buffers[lut_tensor.valueBuffer] - values = np.frombuffer(value_buffer.data, dtype=tensor_obj.numpy_dtype) - - return bitwidth, indices, values - - def _make_indices(self, s: str) -> bitarray.bitarray: - """Helper: makes indices from "01" strings for use as expected values.""" - return bitarray.bitarray(s, endian="big") - - def test_compressed_uint8(self): - bitwidth, indices, values = self._get_compressed(subgraph=0, tensor=0) - self.assertEqual(bitwidth, 4) - - # yapf: disable - expected_indices = self._make_indices(""" - 0000 0001 0010 0011 - 0100 0101 0110 0111 - 1000 1001 1010 1011 - 1100 1101 1110 1111 - """) - # yapf: enable - self.assertEqual(indices, expected_indices) - - expected_values = np.array(range(16), dtype=" Date: Sun, 24 May 2026 23:33:10 -0500 Subject: [PATCH 11/19] test(compression): add integration tests with TFLM interpreter Add tests that compress models with LUT compression, run them through the TFLM Python interpreter, and verify outputs match uncompressed originals. Cover per-tensor and per-channel quantization, various index bitwidths, unquantized weights, and alternate decompression memory. BUG=part of #3256 --- tensorflow/lite/micro/compression/BUILD | 25 + .../compression_integration_test.py | 505 ++++++++++++++++++ 2 files changed, 530 insertions(+) create mode 100644 tensorflow/lite/micro/compression/compression_integration_test.py diff --git a/tensorflow/lite/micro/compression/BUILD b/tensorflow/lite/micro/compression/BUILD index 3478d584b56..7034f8b44fd 100644 --- a/tensorflow/lite/micro/compression/BUILD +++ b/tensorflow/lite/micro/compression/BUILD @@ -169,6 +169,31 @@ py_test( ], ) +tflm_py_test( + name = "compression_integration_test", + size = "small", + srcs = ["compression_integration_test.py"], + tags = [ + "noasan", + "nomsan", + "noubsan", + ], + # Only run when compression IS enabled + target_compatible_with = select({ + "//:with_compression_enabled": [], + "//conditions:default": ["@platforms//:incompatible"], + }), + deps = [ + ":compress_lib", + ":decode_insert", + ":model_editor", + ":spec", + "//python/tflite_micro:runtime", + "//tensorflow/lite/python:schema_py", + requirement("numpy"), + ], +) + tflm_py_library( name = "spec", srcs = ["spec.py"], diff --git a/tensorflow/lite/micro/compression/compression_integration_test.py b/tensorflow/lite/micro/compression/compression_integration_test.py new file mode 100644 index 00000000000..0e92a527f6a --- /dev/null +++ b/tensorflow/lite/micro/compression/compression_integration_test.py @@ -0,0 +1,505 @@ +# Copyright 2026 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Integration tests for compression with TFLM interpreter. + +These tests verify that compressed models produce correct inference results +when run through the TFLM Python interpreter. Tests compress models and +compare outputs against uncompressed originals. + +These tests only run when compression is enabled (--//:with_compression). +""" + +import os +import unittest +import numpy as np + +from tflite_micro.python.tflite_micro import runtime +from tflite_micro.tensorflow.lite.micro.compression import compress +from tflite_micro.tensorflow.lite.micro.compression import decode_insert +from tflite_micro.tensorflow.lite.micro.compression import model_editor +from tflite_micro.tensorflow.lite.micro.compression import spec +from tflite_micro.tensorflow.lite.python import schema_py_generated as tflite + + +def _build_compressible_model(weight_shape=(4, 4), + index_bitwidth=2, + per_channel=False, + unquantized=False): + """Build a model with clustered weights for compression testing. + + Args: + weight_shape: Shape of the weight tensor as (rows, cols). + index_bitwidth: Bits per index. Determines unique value count (2^bitwidth). + per_channel: If True, use per-channel quantization (one scale per row). + unquantized: If True, omit quantization from weights. + + Returns: + A TFLite flatbuffer (bytes) containing a simple FULLY_CONNECTED model + with weights that have limited unique values per channel. + """ + rows, cols = weight_shape + unique_count = 2**index_bitwidth + + # Create weights with limited unique values per channel + pattern = np.arange(1, unique_count + 1, dtype=np.int8) + weight_data = np.resize(pattern, (rows, cols)) + + if unquantized: + quantization = None + elif per_channel: + # Per-channel: one scale per output channel (row in FC weights) + scales = [0.5 + 0.1 * i for i in range(rows)] + zero_points = [0] * rows + quantization = model_editor.Quantization( + scales=scales, + zero_points=zero_points, + axis=0, + ) + else: + quantization = model_editor.Quantization(scales=0.5, zero_points=0) + + weights = model_editor.Tensor( + shape=weight_shape, + dtype=tflite.TensorType.INT8, + data=weight_data, + name="weights", + quantization=quantization, + ) + + input_t = model_editor.Tensor( + shape=(1, cols), + dtype=tflite.TensorType.INT8, + name="input", + ) + output_t = model_editor.Tensor( + shape=(1, rows), + dtype=tflite.TensorType.INT8, + name="output", + ) + + model = model_editor.Model(subgraphs=[ + model_editor.Subgraph( + tensors=[weights], + inputs=[input_t], + outputs=[output_t], + operators=[ + model_editor.Operator( + opcode=tflite.BuiltinOperator.FULLY_CONNECTED, + inputs=[input_t, weights], + outputs=[output_t], + ) + ], + ) + ]) + return model.build() + + +class LutCompressionTest(unittest.TestCase): + """Integration tests for LUT (lookup table) compression.""" + + def test_lut_compressed_model_matches_uncompressed(self): + """LUT-compressed model produces same outputs as uncompressed.""" + flatbuffer = _build_compressible_model() + + # Create compression spec for weights tensor (index 0 in tensors list) + specs = [ + spec.Tensor( + subgraph=0, + tensor=0, + compression=[spec.LookUpTableCompression(index_bitwidth=2)], + ) + ] + + # Compress + compressed_fb = compress.compress(flatbuffer, specs) + + # Run inference on both (convert bytearray to bytes for interpreter) + uncompressed_interp = runtime.Interpreter.from_bytes(bytes(flatbuffer)) + compressed_interp = runtime.Interpreter.from_bytes(bytes(compressed_fb)) + + # Test with multiple random inputs + np.random.seed(42) + for _ in range(10): + test_input = np.random.randint(-128, 127, (1, 4), dtype=np.int8) + + uncompressed_interp.set_input(test_input, 0) + uncompressed_interp.invoke() + expected = uncompressed_interp.get_output(0) + + compressed_interp.set_input(test_input, 0) + compressed_interp.invoke() + actual = compressed_interp.get_output(0) + + np.testing.assert_array_equal(expected, actual) + + def test_lut_decode_operators_present(self): + """DECODE operators are inserted for LUT-compressed tensors.""" + flatbuffer = _build_compressible_model() + + specs = [ + spec.Tensor( + subgraph=0, + tensor=0, + compression=[spec.LookUpTableCompression(index_bitwidth=2)], + ) + ] + + compressed_fb = compress.compress(flatbuffer, specs) + model = model_editor.read(compressed_fb) + sg = model.subgraphs[0] + + # Find DECODE operators + decode_ops = [ + op for op in sg.operators if op.opcode == tflite.BuiltinOperator.CUSTOM + and op.custom_code == decode_insert.DECODE_CUSTOM_OP_NAME + ] + + self.assertEqual(len(decode_ops), 1) + + def test_lut_compressed_model_is_smaller(self): + """LUT-compressed model is smaller than original. + + Uses a large enough weight tensor (64x64 = 4096 bytes) that compression + savings outweigh the overhead from lookup tables and DECODE operators. + With 2-bit indices, 4096 bytes becomes 1024 bytes of indices. + """ + flatbuffer = _build_compressible_model(weight_shape=(64, 64)) + + specs = [ + spec.Tensor( + subgraph=0, + tensor=0, + compression=[spec.LookUpTableCompression(index_bitwidth=2)], + ) + ] + + compressed_fb = compress.compress(flatbuffer, specs) + + original_size = len(flatbuffer) + compressed_size = len(compressed_fb) + + self.assertLess( + compressed_size, original_size, + f"Compressed model ({compressed_size} bytes) should be smaller than " + f"original ({original_size} bytes)") + + def test_lut_4bit_compression(self): + """4-bit LUT compression produces correct inference results.""" + flatbuffer = _build_compressible_model(index_bitwidth=4) + + specs = [ + spec.Tensor( + subgraph=0, + tensor=0, + compression=[spec.LookUpTableCompression(index_bitwidth=4)], + ) + ] + + compressed_fb = compress.compress(flatbuffer, specs) + + uncompressed_interp = runtime.Interpreter.from_bytes(bytes(flatbuffer)) + compressed_interp = runtime.Interpreter.from_bytes(bytes(compressed_fb)) + + test_input = np.array([[1, 2, 3, 4]], dtype=np.int8) + + uncompressed_interp.set_input(test_input, 0) + uncompressed_interp.invoke() + expected = uncompressed_interp.get_output(0) + + compressed_interp.set_input(test_input, 0) + compressed_interp.invoke() + actual = compressed_interp.get_output(0) + + np.testing.assert_array_equal(expected, actual) + + def test_lut_per_channel_quantization(self): + """Per-channel quantized weights compress and decompress correctly.""" + flatbuffer = _build_compressible_model(per_channel=True) + + specs = [ + spec.Tensor( + subgraph=0, + tensor=0, + compression=[spec.LookUpTableCompression(index_bitwidth=2)], + ) + ] + + compressed_fb = compress.compress(flatbuffer, specs) + + uncompressed_interp = runtime.Interpreter.from_bytes(bytes(flatbuffer)) + compressed_interp = runtime.Interpreter.from_bytes(bytes(compressed_fb)) + + test_input = np.array([[1, 2, 3, 4]], dtype=np.int8) + + uncompressed_interp.set_input(test_input, 0) + uncompressed_interp.invoke() + expected = uncompressed_interp.get_output(0) + + compressed_interp.set_input(test_input, 0) + compressed_interp.invoke() + actual = compressed_interp.get_output(0) + + np.testing.assert_array_equal(expected, actual) + + def test_lut_unquantized_weights(self): + """Unquantized weights compress and decompress correctly.""" + flatbuffer = _build_compressible_model(unquantized=True) + + specs = [ + spec.Tensor( + subgraph=0, + tensor=0, + compression=[spec.LookUpTableCompression(index_bitwidth=2)], + ) + ] + + compressed_fb = compress.compress(flatbuffer, specs) + + uncompressed_interp = runtime.Interpreter.from_bytes(bytes(flatbuffer)) + compressed_interp = runtime.Interpreter.from_bytes(bytes(compressed_fb)) + + test_input = np.array([[1, 2, 3, 4]], dtype=np.int8) + + uncompressed_interp.set_input(test_input, 0) + uncompressed_interp.invoke() + expected = uncompressed_interp.get_output(0) + + compressed_interp.set_input(test_input, 0) + compressed_interp.invoke() + actual = compressed_interp.get_output(0) + + np.testing.assert_array_equal(expected, actual) + + +def _build_shared_weights_model(): + """Build a model where one compressed tensor is shared between two operators. + + Model structure: + input1 -> [FC1 with weights1] -> output1 + input2 -> [FC2 with weights2] -> intermediate -> [FC3 with weights1] -> output2 + + weights1 is shared between FC1 and FC3. weights2 is used only by FC2, which + runs between the two consumers of weights1. + """ + # 4 unique values per tensor for 2-bit LUT compression. Small values avoid + # saturation in chained layers. Different row sums produce varied outputs. + weights1_data = np.array([ + [-1, 0, 0, 1], + [-1, 0, 1, 1], + [-1, 1, 1, 1], + [0, 1, 1, 1], + ], + dtype=np.int8) + weights1 = model_editor.Tensor( + shape=(4, 4), + dtype=tflite.TensorType.INT8, + data=weights1_data, + name="weights1", + quantization=model_editor.Quantization(scales=1.0, zero_points=0), + ) + + weights2_data = np.array([ + [1, 1, 1, 1], + [1, 1, 2, 2], + [1, 2, 2, 3], + [2, 2, 3, 3], + ], + dtype=np.int8) + weights2 = model_editor.Tensor( + shape=(4, 4), + dtype=tflite.TensorType.INT8, + data=weights2_data, + name="weights2", + quantization=model_editor.Quantization(scales=1.0, zero_points=0), + ) + + # All tensors need matching quantization for FULLY_CONNECTED + quant = model_editor.Quantization(scales=1.0, zero_points=0) + + input1 = model_editor.Tensor( + shape=(1, 4), + dtype=tflite.TensorType.INT8, + name="input1", + quantization=quant, + ) + input2 = model_editor.Tensor( + shape=(1, 4), + dtype=tflite.TensorType.INT8, + name="input2", + quantization=quant, + ) + output1 = model_editor.Tensor( + shape=(1, 4), + dtype=tflite.TensorType.INT8, + name="output1", + quantization=quant, + ) + intermediate = model_editor.Tensor( + shape=(1, 4), + dtype=tflite.TensorType.INT8, + name="intermediate", + quantization=quant, + ) + output2 = model_editor.Tensor( + shape=(1, 4), + dtype=tflite.TensorType.INT8, + name="output2", + quantization=quant, + ) + + model = model_editor.Model(subgraphs=[ + model_editor.Subgraph( + tensors=[weights1, weights2], + inputs=[input1, input2], + outputs=[output1, output2], + operators=[ + # FC1: uses weights1 + model_editor.Operator( + opcode=tflite.BuiltinOperator.FULLY_CONNECTED, + inputs=[input1, weights1], + outputs=[output1], + ), + # FC2: uses weights2 (runs between FC1 and FC3) + model_editor.Operator( + opcode=tflite.BuiltinOperator.FULLY_CONNECTED, + inputs=[input2, weights2], + outputs=[intermediate], + ), + # FC3: uses weights1 (second consumer, after DECODE(weights2)) + model_editor.Operator( + opcode=tflite.BuiltinOperator.FULLY_CONNECTED, + inputs=[intermediate, weights1], + outputs=[output2], + ), + ], + ) + ]) + return model.build() + + +class AltDecompressionMemoryTest(unittest.TestCase): + """Tests for alternate decompression memory with shared compressed tensors. + + These tests verify correct behavior when compressed tensors are shared + between multiple operators and alternate decompression memory is enabled. + """ + + def test_shared_compressed_tensor_with_alt_memory(self): + """Verify correct results when a shared compressed tensor is used with alt + decompression memory. + + This test uses a graph where a compressed tensor (weights1) is consumed by + two operators (FC1 and FC3), with an intervening DECODE of a different + compressed tensor (weights2) between them. + + The interpreter's alternate decompression memory has a limitation: each + DECODE's Prepare resets the allocation offset to zero. This means all + DECODE outputs are allocated at the same address, so they overwrite each + other. A DECODE output can only be used until the next DECODE runs. + + To work around this limitation, the DECODE insertion code inserts a + separate DECODE immediately before each consumer of a compressed tensor, + rather than sharing one DECODE output among all consumers. + """ + flatbuffer = _build_shared_weights_model() + + specs = [ + spec.Tensor( + subgraph=0, + tensor=0, # weights1 + compression=[spec.LookUpTableCompression(index_bitwidth=2)], + ), + spec.Tensor( + subgraph=0, + tensor=1, # weights2 + compression=[spec.LookUpTableCompression(index_bitwidth=2)], + ), + ] + + compressed_fb = compress.compress(flatbuffer, specs) + + # Run without alt decompression memory (baseline) + interp_no_alt = runtime.Interpreter.from_bytes(bytes(compressed_fb)) + + # Run with alt decompression memory + interp_with_alt = runtime.Interpreter.from_bytes( + bytes(compressed_fb), + alt_decompression_memory_size=256, + ) + + test_input1 = np.array([[1, 1, 1, 1]], dtype=np.int8) + test_input2 = np.array([[1, 1, 1, 1]], dtype=np.int8) + + interp_no_alt.set_input(test_input1, 0) + interp_no_alt.set_input(test_input2, 1) + interp_no_alt.invoke() + expected1 = interp_no_alt.get_output(0) + expected2 = interp_no_alt.get_output(1) + + interp_with_alt.set_input(test_input1, 0) + interp_with_alt.set_input(test_input2, 1) + interp_with_alt.invoke() + actual1 = interp_with_alt.get_output(0) + actual2 = interp_with_alt.get_output(1) + + np.testing.assert_array_equal( + expected1, actual1, "Output 1 mismatch with alt decompression memory") + np.testing.assert_array_equal( + expected2, actual2, "Output 2 mismatch with alt decompression memory") + + +class HuffmanCompressionTest(unittest.TestCase): + """Integration tests for Huffman compression.""" + + @unittest.skip("Huffman compression not yet implemented") + def test_huffman_compressed_model_matches_uncompressed(self): + """Huffman-compressed model produces same outputs as uncompressed.""" + pass + + @unittest.skip("Huffman compression not yet implemented") + def test_huffman_decode_operators_present(self): + """DECODE operators are inserted for Huffman-compressed tensors.""" + pass + + @unittest.skip("Huffman compression not yet implemented") + def test_huffman_compressed_model_is_smaller(self): + """Huffman-compressed model is smaller than original.""" + pass + + +class PruningCompressionTest(unittest.TestCase): + """Integration tests for pruning compression.""" + + @unittest.skip("Pruning compression not yet implemented") + def test_pruning_compressed_model_matches_uncompressed(self): + """Pruning-compressed model produces same outputs as uncompressed.""" + pass + + @unittest.skip("Pruning compression not yet implemented") + def test_pruning_decode_operators_present(self): + """DECODE operators are inserted for pruning-compressed tensors.""" + pass + + @unittest.skip("Pruning compression not yet implemented") + def test_pruning_compressed_model_is_smaller(self): + """Pruning-compressed model is smaller than original.""" + pass + + +if __name__ == "__main__": + # Suppress TF C++ info/debug logs (0=DEBUG, 1=INFO, 2=WARNING, 3=ERROR) + os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2" + # Disable oneDNN to avoid non-deterministic floating point results + os.environ["TF_ENABLE_ONEDNN_OPTS"] = "0" + unittest.main() From ffae7dddfed5b1a4382c8f027eaf742e5a5777e8 Mon Sep 17 00:00:00 2001 From: Ryan Kuester Date: Sun, 24 May 2026 23:36:40 -0500 Subject: [PATCH 12/19] test(compression): add proprietary model integration test Add a manual test for verifying compression on proprietary models that can't be checked into the repository. See the module docstring for usage instructions. BUG=part of #3256 --- tensorflow/lite/micro/compression/BUILD | 24 ++ .../proprietary_integration_test.py | 211 ++++++++++++++++++ 2 files changed, 235 insertions(+) create mode 100644 tensorflow/lite/micro/compression/proprietary_integration_test.py diff --git a/tensorflow/lite/micro/compression/BUILD b/tensorflow/lite/micro/compression/BUILD index 7034f8b44fd..d8c017203cf 100644 --- a/tensorflow/lite/micro/compression/BUILD +++ b/tensorflow/lite/micro/compression/BUILD @@ -194,6 +194,30 @@ tflm_py_test( ], ) +tflm_py_test( + name = "proprietary_integration_test", + size = "small", + srcs = ["proprietary_integration_test.py"], + tags = [ + "manual", + "noasan", + "nomsan", + "noubsan", + ], + target_compatible_with = select({ + "//:with_compression_enabled": [], + "//conditions:default": ["@platforms//:incompatible"], + }), + deps = [ + ":compress_lib", + ":model_editor", + ":spec", + "//python/tflite_micro:runtime", + "//tensorflow/lite/python:schema_py", + requirement("numpy"), + ], +) + tflm_py_library( name = "spec", srcs = ["spec.py"], diff --git a/tensorflow/lite/micro/compression/proprietary_integration_test.py b/tensorflow/lite/micro/compression/proprietary_integration_test.py new file mode 100644 index 00000000000..684805d0f56 --- /dev/null +++ b/tensorflow/lite/micro/compression/proprietary_integration_test.py @@ -0,0 +1,211 @@ +# Copyright 2026 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Integration tests for compression using proprietary models. + +These tests verify that compressed models produce correct inference results +when run through the TFLM Python interpreter. Tests compress models and +compare outputs against uncompressed originals using random inputs. + +This test is tagged `manual` and requires a path to a directory containing +.tflite model files. + +Usage: + bazel test //tensorflow/lite/micro/compression:proprietary_integration_test \ + --//:with_compression \ + --test_arg=/path/to/models + +Required files: + Each model requires a compression spec file: + model.spec.yaml (replacing .tflite extension) + + See spec.py for the YAML format. Example: + tensors: + - subgraph: 0 + tensor: 2 + compression: + - lut: + index_bitwidth: 4 + +Optional files: + model.config.json (replacing .tflite extension) + Tolerance overrides: {"rtol": 1e-5, "atol": 1e-6} + Default is exact match (rtol=0, atol=0). +""" + +import glob +import json +import os +import sys +import unittest + +import numpy as np + +from tflite_micro.python.tflite_micro import runtime +from tflite_micro.tensorflow.lite.micro.compression import compress +from tflite_micro.tensorflow.lite.micro.compression import model_editor +from tflite_micro.tensorflow.lite.micro.compression import spec +from tflite_micro.tensorflow.lite.python import schema_py_generated as tflite + + +def _dtype_to_numpy(dtype: tflite.TensorType) -> np.dtype: + """Convert TFLite dtype to numpy dtype.""" + type_map = { + tflite.TensorType.INT8: np.int8, + tflite.TensorType.INT16: np.int16, + tflite.TensorType.INT32: np.int32, + tflite.TensorType.INT64: np.int64, + tflite.TensorType.UINT8: np.uint8, + tflite.TensorType.UINT16: np.uint16, + tflite.TensorType.UINT32: np.uint32, + tflite.TensorType.FLOAT16: np.float16, + tflite.TensorType.FLOAT32: np.float32, + tflite.TensorType.FLOAT64: np.float64, + tflite.TensorType.BOOL: np.bool_, + } + return type_map.get(dtype, np.uint8) + + +class ProprietaryModelTest(unittest.TestCase): + """Integration tests using proprietary models.""" + + # Parsed from command line in main() + models_dir = None + + @classmethod + def setUpClass(cls): + if not cls.models_dir: + raise unittest.SkipTest( + "No models directory provided. " + "Usage: bazel test ... --test_arg=/path/to/models") + + cls.model_paths = sorted( + glob.glob(os.path.join(cls.models_dir, '*.tflite'))) + if not cls.model_paths: + raise unittest.SkipTest(f"No .tflite files found in {cls.models_dir}") + + def test_all_models(self): + """Run compression test on each discovered model.""" + for model_path in self.model_paths: + with self.subTest(model=os.path.basename(model_path)): + self._test_model_compression(model_path) + + def _test_model_compression(self, model_path): + """Test that a compressed model produces same outputs as original.""" + with open(model_path, 'rb') as f: + flatbuffer = f.read() + + # Load compression spec from sidecar file + specs = self._load_compression_spec(model_path) + + # Load tolerance config + rtol, atol = self._load_tolerance(model_path) + + # Compress the model + compressed_fb = compress.compress(flatbuffer, specs) + + # Create interpreters + original_interp = runtime.Interpreter.from_bytes(bytes(flatbuffer)) + compressed_interp = runtime.Interpreter.from_bytes(bytes(compressed_fb)) + + # Generate random inputs and compare outputs + np.random.seed(42) + model = model_editor.read(flatbuffer) + sg = model.subgraphs[0] + + for trial in range(5): + # Set inputs + for i, input_tensor in enumerate(sg.inputs): + test_input = self._generate_input(input_tensor) + original_interp.set_input(test_input, i) + compressed_interp.set_input(test_input, i) + + # Run inference + original_interp.invoke() + compressed_interp.invoke() + + # Compare outputs + for i in range(len(sg.outputs)): + expected = original_interp.get_output(i) + actual = compressed_interp.get_output(i) + self._compare_outputs(expected, actual, rtol, atol, + f"trial {trial}, output {i}") + + def _generate_input(self, tensor): + """Generate random input respecting tensor dtype.""" + shape = tensor.shape + dtype = _dtype_to_numpy(tensor.dtype) + + if np.issubdtype(dtype, np.floating): + return np.random.uniform(-1.0, 1.0, shape).astype(dtype) + elif np.issubdtype(dtype, np.integer): + info = np.iinfo(dtype) + return np.random.randint(info.min, info.max + 1, shape, dtype=dtype) + elif dtype == np.bool_: + return np.random.choice([False, True], shape) + return np.zeros(shape, dtype=dtype) + + def _load_compression_spec(self, model_path): + """Load compression spec from sidecar YAML file. + + Raises: + FileNotFoundError: If no spec file is found. + """ + spec_path = model_path.replace('.tflite', '.spec.yaml') + if os.path.exists(spec_path): + with open(spec_path) as f: + return spec.parse_yaml(f.read()) + + raise FileNotFoundError( + f"No compression spec file found for {model_path}. " + f"Expected: {spec_path}") + + def _load_tolerance(self, model_path): + """Load tolerance from sidecar config if present. + + Returns (0, 0) for exact match if no config file exists. + """ + config_path = model_path.replace('.tflite', '.config.json') + if os.path.exists(config_path): + with open(config_path) as f: + config = json.load(f) + return config.get('rtol', 0), config.get('atol', 0) + return 0, 0 + + def _compare_outputs(self, expected, actual, rtol, atol, context=""): + """Compare outputs with optional tolerance.""" + msg = f"Output mismatch ({context})" if context else "Output mismatch" + if rtol == 0 and atol == 0: + np.testing.assert_array_equal(expected, actual, err_msg=msg) + else: + np.testing.assert_allclose(expected, + actual, + rtol=rtol, + atol=atol, + err_msg=msg) + + +if __name__ == "__main__": + # Suppress TF C++ info/debug logs (0=DEBUG, 1=INFO, 2=WARNING, 3=ERROR) + os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2" + # Disable oneDNN to avoid non-deterministic floating point results + os.environ["TF_ENABLE_ONEDNN_OPTS"] = "0" + + # Parse models directory from args, then strip it so tf.test doesn't see it + for arg in sys.argv[1:]: + if not arg.startswith('-') and os.path.isdir(arg): + ProprietaryModelTest.models_dir = arg + sys.argv.remove(arg) + break + + unittest.main() From c4993089badc7fb86b6aab78a4bddd5829679d80 Mon Sep 17 00:00:00 2001 From: Ryan Kuester Date: Sun, 24 May 2026 23:38:01 -0500 Subject: [PATCH 13/19] refactor(compression): compressors inherit from Compressor protocol Explicit inheritance from Protocol enables static type checking at definition time and makes the interface self-documenting. BUG=part of #3256 --- tensorflow/lite/micro/compression/huffman.py | 2 +- tensorflow/lite/micro/compression/lut.py | 2 +- tensorflow/lite/micro/compression/pruning.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/tensorflow/lite/micro/compression/huffman.py b/tensorflow/lite/micro/compression/huffman.py index 40d0be9284a..e539827eae4 100644 --- a/tensorflow/lite/micro/compression/huffman.py +++ b/tensorflow/lite/micro/compression/huffman.py @@ -25,7 +25,7 @@ from tflite_micro.tensorflow.lite.micro.compression import spec -class HuffmanCompressor: +class HuffmanCompressor(compressor.Compressor): """Huffman compression plugin (stub). This stub exists to validate the plugin architecture. The actual Huffman diff --git a/tensorflow/lite/micro/compression/lut.py b/tensorflow/lite/micro/compression/lut.py index def34059ac5..991288f54cc 100644 --- a/tensorflow/lite/micro/compression/lut.py +++ b/tensorflow/lite/micro/compression/lut.py @@ -241,7 +241,7 @@ def pack_lookup_tables(tables: list[np.ndarray], table_len: int) -> bytes: return bytes(buffer) -class LutCompressor: +class LutCompressor(compressor.Compressor): """LUT compression plugin implementing the Compressor protocol.""" @property diff --git a/tensorflow/lite/micro/compression/pruning.py b/tensorflow/lite/micro/compression/pruning.py index 2181b73e34a..5c95e3e87e9 100644 --- a/tensorflow/lite/micro/compression/pruning.py +++ b/tensorflow/lite/micro/compression/pruning.py @@ -25,7 +25,7 @@ from tflite_micro.tensorflow.lite.micro.compression import spec -class PruningCompressor: +class PruningCompressor(compressor.Compressor): """Pruning compression plugin (stub). This stub exists to validate the plugin architecture. The actual pruning From fc92d3cfec5d7456f2d2f6b11ae4f07d7bb191a7 Mon Sep 17 00:00:00 2001 From: Ryan Kuester Date: Mon, 25 May 2026 00:05:51 -0500 Subject: [PATCH 14/19] test(python): rewrite unsupported-compression test for legacy path An upcoming change registers the DECODE operator unconditionally in the Python ops resolver, after which compress() emits DECODE-based models that load successfully. That breaks this test's original approach, which ran a model through compress() and expected the load to fail. Rewrite it to instead inject a raw COMPRESSION_METADATA entry into the flatbuffer via model_editor, directly exercising the HasCompressionMetadata() detection path for legacy-compressed models. Decoupling the test from compress() output lets it verify the legacy-rejection behavior independently of whether the DECODE operator is registered, so it passes both before and after that upcoming change. BUG=part of #3256 --- python/tflite_micro/BUILD | 2 +- .../test_compression_unsupported.py | 96 +++++++++---------- 2 files changed, 48 insertions(+), 50 deletions(-) diff --git a/python/tflite_micro/BUILD b/python/tflite_micro/BUILD index b358fd12adc..812cf7092fd 100644 --- a/python/tflite_micro/BUILD +++ b/python/tflite_micro/BUILD @@ -125,7 +125,7 @@ py_test( ":runtime", requirement("numpy"), requirement("tensorflow"), - "//tensorflow/lite/micro/compression", + "//tensorflow/lite/micro/compression:model_editor", ], ) diff --git a/python/tflite_micro/test_compression_unsupported.py b/python/tflite_micro/test_compression_unsupported.py index 3692dd0a43a..edd47808298 100644 --- a/python/tflite_micro/test_compression_unsupported.py +++ b/python/tflite_micro/test_compression_unsupported.py @@ -12,84 +12,82 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Test compression metadata detection when compression is disabled.""" +"""Test legacy compression metadata detection when compression is disabled.""" import os import numpy as np import tensorflow as tf from tflite_micro.python.tflite_micro import runtime -from tflite_micro.tensorflow.lite.micro import compression +from tflite_micro.tensorflow.lite.micro.compression import model_editor -class CompressionDetectionTest(tf.test.TestCase): - """Test compression metadata detection when compression is disabled.""" +def _create_test_model(): + """Create a simple quantized model for testing.""" + model = tf.keras.Sequential([ + tf.keras.layers.Dense(10, input_shape=(5, ), activation='relu'), + tf.keras.layers.Dense(5, activation='softmax') + ]) + model.compile(optimizer='adam', loss='sparse_categorical_crossentropy') - def _create_test_model(self): - """Create a simple quantized model for testing.""" - model = tf.keras.Sequential([ - tf.keras.layers.Dense(10, input_shape=(5, ), activation='relu'), - tf.keras.layers.Dense(5, activation='softmax') - ]) - model.compile(optimizer='adam', loss='sparse_categorical_crossentropy') + converter = tf.lite.TFLiteConverter.from_keras_model(model) + converter.optimizations = [tf.lite.Optimize.DEFAULT] - # Convert to quantized TFLite - converter = tf.lite.TFLiteConverter.from_keras_model(model) - converter.optimizations = [tf.lite.Optimize.DEFAULT] + def representative_dataset(): + for _ in range(10): + yield [np.random.randn(1, 5).astype(np.float32)] - def representative_dataset(): - for _ in range(10): - yield [np.random.randn(1, 5).astype(np.float32)] + converter.representative_dataset = representative_dataset + converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8] + converter.inference_input_type = tf.uint8 + converter.inference_output_type = tf.uint8 - converter.representative_dataset = representative_dataset - converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8] - converter.inference_input_type = tf.uint8 - converter.inference_output_type = tf.uint8 + tflite_model = converter.convert() + return bytes(tflite_model) if isinstance(tflite_model, + bytearray) else tflite_model - tflite_model = converter.convert() - return bytes(tflite_model) if isinstance(tflite_model, - bytearray) else tflite_model + +def _inject_compression_metadata(model_data): + """Inject raw COMPRESSION_METADATA into a model's flatbuffer metadata. + + This simulates a legacy-compressed model (one that uses the + COMPRESSION_METADATA metadata entry and kernel-level decompression) without + going through compress(), which now produces DECODE-based output. + """ + model = model_editor.read(model_data) + model.metadata["COMPRESSION_METADATA"] = b"\x00" + return bytes(model.build()) + + +class LegacyCompressionDetectionTest(tf.test.TestCase): + """Test that legacy COMPRESSION_METADATA is rejected without the flag.""" def test_regular_model_loads_successfully(self): """Non-compressed models should load without issues.""" - model_data = self._create_test_model() + model_data = _create_test_model() interpreter = runtime.Interpreter.from_bytes(model_data) self.assertIsNotNone(interpreter) - def test_compressed_model_raises_runtime_error(self): - """Compressed models should raise RuntimeError when compression is disabled.""" - # Create and compress a model - model_data = self._create_test_model() + def test_legacy_compressed_model_raises_runtime_error(self): + """Models with COMPRESSION_METADATA should raise RuntimeError.""" + model_data = _create_test_model() + legacy_model = _inject_compression_metadata(model_data) - spec = (compression.SpecBuilder().add_tensor( - subgraph=0, tensor=1).with_lut(index_bitwidth=4).build()) - - compressed_model = compression.compress(model_data, spec) - if isinstance(compressed_model, bytearray): - compressed_model = bytes(compressed_model) - - # Should raise RuntimeError with self.assertRaises(RuntimeError): - runtime.Interpreter.from_bytes(compressed_model) - - def test_can_load_regular_after_compressed_failure(self): - """Verify we can still load regular models after compressed model fails.""" - model_data = self._create_test_model() + runtime.Interpreter.from_bytes(legacy_model) - # First try compressed model (should fail) - spec = (compression.SpecBuilder().add_tensor( - subgraph=0, tensor=1).with_lut(index_bitwidth=4).build()) - compressed_model = compression.compress(model_data, spec) + def test_can_load_regular_after_legacy_failure(self): + """Verify regular models still load after a legacy-compressed failure.""" + model_data = _create_test_model() + legacy_model = _inject_compression_metadata(model_data) with self.assertRaises(RuntimeError): - runtime.Interpreter.from_bytes(bytes(compressed_model)) + runtime.Interpreter.from_bytes(legacy_model) - # Then load regular model (should succeed) interpreter = runtime.Interpreter.from_bytes(model_data) self.assertIsNotNone(interpreter) if __name__ == '__main__': - # Set TF environment variables to suppress warnings os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' os.environ['TF_ENABLE_ONEDNN_OPTS'] = '0' tf.test.main() From 8871caf7347bfbaaba88f4f371ec26403376e54e Mon Sep 17 00:00:00 2001 From: Ryan Kuester Date: Mon, 25 May 2026 00:06:40 -0500 Subject: [PATCH 15/19] feat(python): register DECODE op unconditionally The DECODE kernel and its dependencies are already compiled unconditionally -- none are guarded by USE_TFLM_COMPRESSION. Remove the #ifdef around AddDecode() in PythonOpsResolver so DECODE-based compressed models work in a default Python build. Remove the with_compression_enabled gating from compression and proprietary integration tests, since they use DECODE-based models that no longer require the flag. BUG=part of #3256 --- python/tflite_micro/python_ops_resolver.cc | 2 -- tensorflow/lite/micro/compression/BUILD | 11 ++--------- 2 files changed, 2 insertions(+), 11 deletions(-) diff --git a/python/tflite_micro/python_ops_resolver.cc b/python/tflite_micro/python_ops_resolver.cc index 19f324bdf2f..5f7d40fb74e 100644 --- a/python/tflite_micro/python_ops_resolver.cc +++ b/python/tflite_micro/python_ops_resolver.cc @@ -40,9 +40,7 @@ PythonOpsResolver::PythonOpsResolver() { AddConv2D(); AddCos(); AddCumSum(); -#ifdef USE_TFLM_COMPRESSION AddDecode(); -#endif AddDelay(); AddDepthToSpace(); AddDepthwiseConv2D(); diff --git a/tensorflow/lite/micro/compression/BUILD b/tensorflow/lite/micro/compression/BUILD index d8c017203cf..c8e313f81a6 100644 --- a/tensorflow/lite/micro/compression/BUILD +++ b/tensorflow/lite/micro/compression/BUILD @@ -178,11 +178,7 @@ tflm_py_test( "nomsan", "noubsan", ], - # Only run when compression IS enabled - target_compatible_with = select({ - "//:with_compression_enabled": [], - "//conditions:default": ["@platforms//:incompatible"], - }), + target_compatible_with = INCOMPATIBLE_WITH_WINDOWS, deps = [ ":compress_lib", ":decode_insert", @@ -204,10 +200,7 @@ tflm_py_test( "nomsan", "noubsan", ], - target_compatible_with = select({ - "//:with_compression_enabled": [], - "//conditions:default": ["@platforms//:incompatible"], - }), + target_compatible_with = INCOMPATIBLE_WITH_WINDOWS, deps = [ ":compress_lib", ":model_editor", From aec05b814871d7bfb22fd51962eeb8efaf6e9283 Mon Sep 17 00:00:00 2001 From: Ryan Kuester Date: Mon, 25 May 2026 00:09:21 -0500 Subject: [PATCH 16/19] test(compression): add tests for batched DECODE insertion Add test_multiple_compressed_inputs_batched: a CONCATENATION with two compressed tensor inputs, each with a different bitwidth, should produce a single DECODE with 4 inputs and 2 outputs, each ancillary tensor carrying its own distinct data. Marked expectedFailure until the implementation lands. Add test_mixed_compressed_and_uncompressed_inputs: a CONCATENATION with one compressed and one plain input leaves the plain input untouched. This already passes with the current code. BUG=part of #3256 --- tensorflow/lite/micro/compression/BUILD | 1 + .../micro/compression/decode_insert_test.py | 153 +++++++++++++++++- 2 files changed, 149 insertions(+), 5 deletions(-) diff --git a/tensorflow/lite/micro/compression/BUILD b/tensorflow/lite/micro/compression/BUILD index c8e313f81a6..a9fe9fa36de 100644 --- a/tensorflow/lite/micro/compression/BUILD +++ b/tensorflow/lite/micro/compression/BUILD @@ -392,6 +392,7 @@ tflm_py_test( ":compressor", ":decode", ":decode_insert", + ":lut", ":model_editor", "//tensorflow/lite/python:schema_py", requirement("numpy"), diff --git a/tensorflow/lite/micro/compression/decode_insert_test.py b/tensorflow/lite/micro/compression/decode_insert_test.py index 11be81963d9..a7e1fb25e8d 100644 --- a/tensorflow/lite/micro/compression/decode_insert_test.py +++ b/tensorflow/lite/micro/compression/decode_insert_test.py @@ -13,14 +13,15 @@ # limitations under the License. """Unit tests for DECODE operator insertion.""" +import unittest import warnings import numpy as np -import unittest from tflite_micro.tensorflow.lite.micro.compression import compressor from tflite_micro.tensorflow.lite.micro.compression import decode from tflite_micro.tensorflow.lite.micro.compression import decode_insert +from tflite_micro.tensorflow.lite.micro.compression import lut from tflite_micro.tensorflow.lite.micro.compression import model_editor from tflite_micro.tensorflow.lite.python import schema_py_generated as tflite @@ -115,14 +116,22 @@ def _build_shared_weights_model(): return model -def _make_dummy_ancillary_data() -> bytes: +def _make_dummy_ancillary_data(bitwidth=4) -> bytes: """Create dummy ancillary data for testing.""" + n_entries = 1 << bitwidth + value_tables = bytes(range(1, n_entries + 1)) + value_tables += b'\x00' * ((-len(value_tables)) % 16) + + lut_data = lut.LutAncillaryData( + bitwidth=bitwidth, + value_table_stride=n_entries, + value_tables=value_tables, + ) dcm = decode.DecodeCommonMetadata( decode_type=decode.DecodeType.LUT, - user_data=b'\x01\x04\x10' + b'\x00' * 9, # lut_version, bitwidth, stride + user_data=lut_data.to_user_data(), ) - value_tables = bytes([1, 2, 3, 4] + [0] * 12) # 16-byte padded table - return dcm.to_bytes() + value_tables + return dcm.to_bytes() + lut_data.to_bytes() class TestDecodeInsertion(unittest.TestCase): @@ -376,6 +385,140 @@ def test_tensor_naming(self): self.assertEqual(ancillary.name, "weights_ancillary") self.assertEqual(output.name, "weights_decoded") + @unittest.expectedFailure + def test_multiple_compressed_inputs_batched(self): + """CONCATENATION with two compressed inputs gets one batched DECODE.""" + weights_a = model_editor.Tensor( + shape=(4, 4), + dtype=tflite.TensorType.INT8, + data=np.ones((4, 4), dtype=np.int8), + name="weights_a", + quantization=model_editor.Quantization(scales=0.5, zero_points=0), + ) + weights_b = model_editor.Tensor( + shape=(4, 4), + dtype=tflite.TensorType.INT8, + data=np.ones((4, 4), dtype=np.int8), + name="weights_b", + quantization=model_editor.Quantization(scales=0.25, zero_points=0), + ) + output_t = model_editor.Tensor( + shape=(4, 8), + dtype=tflite.TensorType.INT8, + name="output", + ) + + concat_op = model_editor.Operator( + opcode=tflite.BuiltinOperator.CONCATENATION, + inputs=[weights_a, weights_b], + outputs=[output_t], + ) + + model = model_editor.Model(subgraphs=[ + model_editor.Subgraph( + tensors=[weights_a, weights_b], + operators=[concat_op], + ) + ]) + + ancillary_a = _make_dummy_ancillary_data(bitwidth=2) + ancillary_b = _make_dummy_ancillary_data(bitwidth=4) + compression_results = { + (0, 0): + compressor.CompressionResult( + encoded_data=b'\x00\x01', + ancillary_data=ancillary_a, + ), + (0, 1): + compressor.CompressionResult( + encoded_data=b'\x02\x03', + ancillary_data=ancillary_b, + ), + } + + decode_insert.insert_decode_operators(model, compression_results) + + sg = model.subgraphs[0] + + # One DECODE + one CONCATENATION + self.assertEqual(len(sg.operators), 2) + decode_op = sg.operators[0] + self.assertEqual(decode_op.opcode, tflite.BuiltinOperator.CUSTOM) + self.assertEqual(decode_op.custom_code, + decode_insert.DECODE_CUSTOM_OP_NAME) + + # DECODE has 4 inputs (enc_a, anc_a, enc_b, anc_b) and 2 outputs + self.assertEqual(len(decode_op.inputs), 4) + self.assertEqual(len(decode_op.outputs), 2) + + # Each ancillary tensor carries its own distinct data + self.assertNotEqual(ancillary_a, ancillary_b) + self.assertEqual(bytes(decode_op.inputs[1].array), ancillary_a) + self.assertEqual(bytes(decode_op.inputs[3].array), ancillary_b) + + # CONCATENATION rewired to DECODE outputs + self.assertIs(sg.operators[1].inputs[0], decode_op.outputs[0]) + self.assertIs(sg.operators[1].inputs[1], decode_op.outputs[1]) + + def test_mixed_compressed_and_uncompressed_inputs(self): + """CONCATENATION with one compressed and one plain input.""" + weights = model_editor.Tensor( + shape=(4, 4), + dtype=tflite.TensorType.INT8, + data=np.ones((4, 4), dtype=np.int8), + name="weights", + quantization=model_editor.Quantization(scales=0.5, zero_points=0), + ) + plain = model_editor.Tensor( + shape=(4, 4), + dtype=tflite.TensorType.INT8, + data=np.zeros((4, 4), dtype=np.int8), + name="plain", + ) + output_t = model_editor.Tensor( + shape=(4, 8), + dtype=tflite.TensorType.INT8, + name="output", + ) + + concat_op = model_editor.Operator( + opcode=tflite.BuiltinOperator.CONCATENATION, + inputs=[weights, plain], + outputs=[output_t], + ) + + model = model_editor.Model(subgraphs=[ + model_editor.Subgraph( + tensors=[weights, plain], + operators=[concat_op], + ) + ]) + + # Only compress weights, not plain + compression_results = { + (0, 0): + compressor.CompressionResult( + encoded_data=b'\x00\x01', + ancillary_data=_make_dummy_ancillary_data(), + ), + } + + decode_insert.insert_decode_operators(model, compression_results) + + sg = model.subgraphs[0] + + # One DECODE + one CONCATENATION + self.assertEqual(len(sg.operators), 2) + decode_op = sg.operators[0] + + # DECODE has 2 inputs and 1 output (only the compressed tensor) + self.assertEqual(len(decode_op.inputs), 2) + self.assertEqual(len(decode_op.outputs), 1) + + # CONCATENATION: first input rewired to DECODE output, second unchanged + self.assertIs(sg.operators[1].inputs[0], decode_op.outputs[0]) + self.assertIs(sg.operators[1].inputs[1], plain) + def test_encoded_tensor_rewritten(self): """Compressed tensor is rewritten with encoded data, UINT8 type, no quant.""" model = _build_simple_fc_model() From f26026c55fe974368da0fd73c369feb6b57a8ee1 Mon Sep 17 00:00:00 2001 From: Ryan Kuester Date: Mon, 25 May 2026 00:39:13 -0500 Subject: [PATCH 17/19] feat(compression): batch multiple compressed tensors per DECODE When a single operator (e.g., CONCATENATION) has multiple compressed tensor inputs, group them into one DECODE instead of creating a separate DECODE for each. Grouping is per-consumer, so a tensor shared across different consumers still gets a separate DECODE before each one to avoid clobbering the alternate decompression memory. BUG=part of #3256 --- .../lite/micro/compression/decode_insert.py | 72 +++++++++++-------- .../micro/compression/decode_insert_test.py | 1 - 2 files changed, 42 insertions(+), 31 deletions(-) diff --git a/tensorflow/lite/micro/compression/decode_insert.py b/tensorflow/lite/micro/compression/decode_insert.py index 43dffce46f0..fa91896e538 100644 --- a/tensorflow/lite/micro/compression/decode_insert.py +++ b/tensorflow/lite/micro/compression/decode_insert.py @@ -210,15 +210,20 @@ def insert_decode_operators( for sg_idx, tensor_infos in by_subgraph.items(): subgraph = model.subgraphs[sg_idx] - # Collect all (consumer, tensor_info) pairs and sort by consumer position - # in reverse order so insertions don't invalidate positions - consumer_pairs = [] + # Group tensor infos by consumer so multiple compressed inputs to the + # same operator get batched into a single DECODE. + consumer_to_infos: dict[model_editor.Operator, list[_CompressedTensorInfo]] + consumer_to_infos = defaultdict(list) for info in tensor_infos: for consumer in info.consumers: - consumer_pairs.append((consumer, info)) - - consumer_pairs.sort( - key=lambda pair: subgraph.operators.index(pair[0]), + if info not in consumer_to_infos[consumer]: + consumer_to_infos[consumer].append(info) + + # Sort consumers by position in reverse so insertions don't invalidate + # earlier positions. + sorted_consumers = sorted( + consumer_to_infos.keys(), + key=lambda op: subgraph.operators.index(op), reverse=True, ) @@ -231,38 +236,45 @@ def insert_decode_operators( # _create_output_tensor reads the original tensor's shape/dtype/quantization. tensors_to_rewrite: dict[model_editor.Tensor, bytes] = {} - for consumer, info in consumer_pairs: - # Reuse or create ancillary data tensor - if info.tensor not in ancillary_cache: - ancillary_tensor = _create_ancillary_tensor( - info.ancillary_data, - info.tensor, - ) - subgraph.tensors.append(ancillary_tensor) - ancillary_cache[info.tensor] = ancillary_tensor - tensors_to_rewrite[info.tensor] = info.encoded_data - else: - ancillary_tensor = ancillary_cache[info.tensor] - - # Create output tensor (one per DECODE) - output_tensor = _create_output_tensor(info.tensor) - subgraph.tensors.append(output_tensor) - - # Create DECODE operator + for consumer in sorted_consumers: + decode_inputs = [] + decode_outputs = [] + + for info in consumer_to_infos[consumer]: + # Reuse or create ancillary data tensor + if info.tensor not in ancillary_cache: + ancillary_tensor = _create_ancillary_tensor( + info.ancillary_data, + info.tensor, + ) + subgraph.tensors.append(ancillary_tensor) + ancillary_cache[info.tensor] = ancillary_tensor + tensors_to_rewrite[info.tensor] = info.encoded_data + else: + ancillary_tensor = ancillary_cache[info.tensor] + + # Create output tensor (one per compressed input) + output_tensor = _create_output_tensor(info.tensor) + subgraph.tensors.append(output_tensor) + + decode_inputs.extend([info.tensor, ancillary_tensor]) + decode_outputs.append(output_tensor) + + # Rewire this consumer to use the decoded output + _rewire_consumers([consumer], info.tensor, output_tensor) + + # Create single DECODE operator for all compressed inputs decode_op = model_editor.Operator( opcode=tflite.BuiltinOperator.CUSTOM, custom_code=DECODE_CUSTOM_OP_NAME, - inputs=[info.tensor, ancillary_tensor], - outputs=[output_tensor], + inputs=decode_inputs, + outputs=decode_outputs, ) # Insert DECODE immediately before this consumer insert_pos = subgraph.operators.index(consumer) subgraph.operators.insert(insert_pos, decode_op) - # Rewire only this consumer to use the decoded output - _rewire_consumers([consumer], info.tensor, output_tensor) - # Rewrite encoded tensors after all output tensors are created for tensor, encoded_data in tensors_to_rewrite.items(): _rewrite_encoded_tensor(tensor, encoded_data) diff --git a/tensorflow/lite/micro/compression/decode_insert_test.py b/tensorflow/lite/micro/compression/decode_insert_test.py index a7e1fb25e8d..60965b46676 100644 --- a/tensorflow/lite/micro/compression/decode_insert_test.py +++ b/tensorflow/lite/micro/compression/decode_insert_test.py @@ -385,7 +385,6 @@ def test_tensor_naming(self): self.assertEqual(ancillary.name, "weights_ancillary") self.assertEqual(output.name, "weights_decoded") - @unittest.expectedFailure def test_multiple_compressed_inputs_batched(self): """CONCATENATION with two compressed inputs gets one batched DECODE.""" weights_a = model_editor.Tensor( From 18e3a57523388df9f95b003a7c6a9d47993a8203 Mon Sep 17 00:00:00 2001 From: Ryan Kuester Date: Mon, 25 May 2026 00:40:14 -0500 Subject: [PATCH 18/19] feat(compression): reject empty compression spec An empty spec list passed to compress() previously returned an unmodified model silently. Fail early with a clear error instead, since an empty spec is almost certainly a mistake. BUG=part of #3256 --- tensorflow/lite/micro/compression/compress.py | 5 +++++ tensorflow/lite/micro/compression/compress_test.py | 5 +++++ 2 files changed, 10 insertions(+) diff --git a/tensorflow/lite/micro/compression/compress.py b/tensorflow/lite/micro/compression/compress.py index 270951fecf8..96b55d94fd7 100644 --- a/tensorflow/lite/micro/compression/compress.py +++ b/tensorflow/lite/micro/compression/compress.py @@ -132,6 +132,11 @@ def compress(model_in: ByteString, specs: Iterable[spec.Tensor]) -> bytearray: Returns: A compressed flatbuffer with DECODE operators inserted. """ + specs = list(specs) + if not specs: + raise compressor.CompressionError( + "Compression spec is empty; no tensors to compress") + model = model_editor.read(model_in) compression_results: dict[tuple[int, int], compressor.CompressionResult] = {} diff --git a/tensorflow/lite/micro/compression/compress_test.py b/tensorflow/lite/micro/compression/compress_test.py index cb241c2c62f..6ee80f200d5 100644 --- a/tensorflow/lite/micro/compression/compress_test.py +++ b/tensorflow/lite/micro/compression/compress_test.py @@ -313,6 +313,11 @@ def test_ancillary_data_format(self): self.assertEqual(dcm_bytes[5] & 0x07, 4) # bitwidth = 4 self.assertEqual(dcm_bytes[6], 4) # stride = num unique values + def test_empty_spec_raises(self): + """Empty compression spec is an error, not a silent no-op.""" + self.assertRaisesRegex(compressor.CompressionError, "empty", + lambda: compress.compress(self.flatbuffer, [])) + def test_smaller_bitwidth_raises(self): """Specifying LUT compression with too small a bitwidth fails.""" specs = [ From 75ca9dcd789e70e335ebe6a6ee30e5e257f6b9e1 Mon Sep 17 00:00:00 2001 From: Ryan Kuester Date: Mon, 25 May 2026 00:41:20 -0500 Subject: [PATCH 19/19] docs(python): explain env vars in test runner --- python/tflite_micro/test_compression_unsupported.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/python/tflite_micro/test_compression_unsupported.py b/python/tflite_micro/test_compression_unsupported.py index edd47808298..01c598374ce 100644 --- a/python/tflite_micro/test_compression_unsupported.py +++ b/python/tflite_micro/test_compression_unsupported.py @@ -88,6 +88,8 @@ def test_can_load_regular_after_legacy_failure(self): if __name__ == '__main__': + # Suppress TF C++ info/debug logs (0=DEBUG, 1=INFO, 2=WARNING, 3=ERROR) os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' + # Disable oneDNN to avoid non-deterministic floating point results os.environ['TF_ENABLE_ONEDNN_OPTS'] = '0' tf.test.main()