feat(export): add TFLite export support with FP32/FP16/INT8 #920

mfazrinizar wants to merge 4 commits into roboflow:develop
Conversation
- Use getattr/setattr for dynamic module attribute access instead of direct attribute assignment (attr-defined)
- Add unused-ignore to `type: ignore` comments for the `numpy.load` monkey-patch to handle both mypy versions (unused-ignore)
Pull request overview
Adds a new ONNX → TFLite export path (via onnx2tf) to support deploying RF-DETR detection/segmentation models to TensorFlow Lite with FP32/FP16/INT8 options, including calibration-data handling.
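For context, a minimal sketch of driving the onnx2tf Python API that this export path wraps — parameter names follow onnx2tf's public `convert()` API, but this is not the PR's exact call:

```python
def convert_onnx_to_tflite(onnx_path: str, out_dir: str) -> None:
    # Optional dependency; onnx2tf writes a saved_model plus FP32/FP16
    # .tflite files into out_dir.
    import onnx2tf

    onnx2tf.convert(
        input_onnx_file_path=onnx_path,
        output_folder_path=out_dir,
        output_signaturedefs=True,  # the workaround this PR applies for seg models
    )
```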
Changes:
- Introduces `src/rfdetr/export/_tflite/` with an `export_tflite()` converter and calibration-data preparation utilities.
- Extends `RFDETR.export()` to route `format="tflite"` and forward quantization/calibration parameters.
- Adds/updates docs, optional dependencies, and a comprehensive new unit test suite for the TFLite pipeline.
Reviewed changes
Copilot reviewed 9 out of 9 changed files in this pull request and generated 3 comments.
| File | Description |
|---|---|
| `tests/export/test_tflite_export.py` | Adds extensive unit tests for the new TFLite export pipeline and parameter wiring. |
| `tests/export/__init__.py` | Adds package marker file for export tests. |
| `src/rfdetr/export/main.py` | Adds a post-ONNX step to run `export_tflite()` when `args.tflite` is present. |
| `src/rfdetr/export/_tflite/converter.py` | Implements ONNX → TFLite conversion, calibration preparation, and onnx2tf compatibility/patching. |
| `src/rfdetr/export/_tflite/__init__.py` | Exposes `export_tflite` at the subpackage level. |
| `src/rfdetr/export/__init__.py` | Updates module docstring to mention TFLite. |
| `src/rfdetr/detr.py` | Adds `format="tflite"` export routing + quantization/calibration parameters and docs. |
| `pyproject.toml` | Adds `tflite` optional dependency group and a `tflite` pytest marker. |
| `docs/learn/export.md` | Updates export docs to include TFLite usage + calibration/quantization guidance. |
```python
except ImportError:
    raise ImportError("onnx2tf is not installed. Install it with: pip install rfdetr[tflite]")
```
The `ImportError` message recommends `pip install rfdetr[tflite]`, but the docs and the `RFDETR.export(..., format="tflite")` route require the ONNX export dependencies as well (`rfdetr[onnx,tflite]`). Also, the default calibration path (`calibration_data=None`) calls `_get_onnx_input_info()` and will error if onnx isn't installed. Please align the message (and/or add an explicit ONNX availability check with a clear install hint) to avoid misleading users.
Suggested change:

```diff
-except ImportError:
-    raise ImportError("onnx2tf is not installed. Install it with: pip install rfdetr[tflite]")
+except ImportError as exc:
+    raise ImportError(
+        "onnx2tf is not installed. TFLite export requires both ONNX and "
+        "TFLite export dependencies. Install them with: "
+        "pip install rfdetr[onnx,tflite]"
+    ) from exc
```
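One way to realize the "explicit availability check" the comment asks for — a minimal sketch with an illustrative helper name, not the PR's actual code:

```python
import importlib.util

def check_tflite_export_deps() -> None:
    # Probe for both optional dependencies up front, instead of failing
    # later with a misleading single-extra hint.
    missing = [m for m in ("onnx", "onnx2tf") if importlib.util.find_spec(m) is None]
    if missing:
        raise ImportError(
            f"TFLite export requires {', '.join(missing)}. "
            "Install all export dependencies with: pip install rfdetr[onnx,tflite]"
        )
```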
```toml
onnx = [
    "onnx>=1.16.0,<1.20",
    "onnxsim<0.6.0",  # TODO: onnxsim 0.6.0+ hangs on install
    "onnx_graphsurgeon",
    "onnxruntime",
    "polygraphy",
]
trt = [
    "pycuda",
    "onnxruntime-gpu",
    "tensorrt>=8.6.1",
    "polygraphy",
]
tflite = [
    "onnx2tf>=1.25.0,<3.0.0",
    "flatbuffers>=23.5.26",
]
```
PR description says `onnx2tf>=1.26.0` and `tf-keras>=2.16.0` were added to the `[onnx]` extra, but this diff introduces a new `[tflite]` extra with `onnx2tf>=1.25.0,<3.0.0` + `flatbuffers` and does not add `tf-keras`. Please update either the dependency list (if tf-keras/version pins are actually required) or the PR description/docs so they match the shipped extras.
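For reference, the change the PR description claims would have looked roughly like this instead — hypothetical, the actual diff ships a separate `[tflite]` extra:

```toml
# Hypothetical: what the PR description describes, NOT what this diff ships.
onnx = [
    "onnx>=1.16.0,<1.20",
    "onnxsim<0.6.0",
    "onnx_graphsurgeon",
    "onnxruntime",
    "polygraphy",
    "onnx2tf>=1.26.0",
    "tf-keras>=2.16.0",
]
```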
```python
def test_reads_input_name_and_shape(self, tmp_path: Path) -> None:
    """Build a minimal ONNX model and verify we read back its metadata."""
    import onnx
    from onnx import TensorProto, helper

    inp = helper.make_tensor_value_info("images", TensorProto.FLOAT, [1, 3, 560, 560])
    out = helper.make_tensor_value_info("output", TensorProto.FLOAT, [1, 100, 4])
    node = helper.make_node("Identity", inputs=["images"], outputs=["output"])
    graph = helper.make_graph([node], "test", [inp], [out])
    model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])
    onnx_path = tmp_path / "test_model.onnx"
    onnx.save(model, str(onnx_path))

    name, dims = _get_onnx_input_info(onnx_path)
    assert name == "images"
    assert dims == [1, 3, 560, 560]

def test_different_input_shape(self, tmp_path: Path) -> None:
    """Verify non-square resolution reads correctly."""
    import onnx
    from onnx import TensorProto, helper

    inp = helper.make_tensor_value_info("input", TensorProto.FLOAT, [1, 3, 448, 640])
    out = helper.make_tensor_value_info("output", TensorProto.FLOAT, [1, 10, 4])
    node = helper.make_node("Identity", inputs=["input"], outputs=["output"])
    graph = helper.make_graph([node], "test", [inp], [out])
    model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])
    onnx_path = tmp_path / "test_model.onnx"
    onnx.save(model, str(onnx_path))

    name, dims = _get_onnx_input_info(onnx_path)
    assert name == "input"
    assert dims == [1, 3, 448, 640]
```
These tests import `onnx` directly, but ONNX is an optional extra and the default CPU CI install (editable install without `--extra onnx`) will not have it, causing `ModuleNotFoundError` and failing the suite. Consider using `pytest.importorskip("onnx")` for these ONNX-dependent tests, or mocking/injecting a minimal `onnx` module (similar to existing export tests) so the base test run doesn't require optional dependencies.
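The first option can be sketched as follows — the test body here is illustrative; only the gating line matters:

```python
import pytest

def test_reads_input_name_and_shape(tmp_path):
    # Skip (rather than fail) the test when the optional extra is absent.
    onnx = pytest.importorskip("onnx")
    assert hasattr(onnx, "save")
```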
Codecov Report

❌ Patch coverage is 19%. Your patch check has failed because the patch coverage (19%) is below the target coverage (95%). You can increase the patch coverage or adjust the target coverage.

Additional details and impacted files:

```
@@ Coverage Diff @@
##           develop    #920    +/- ##
========================================
- Coverage       79%     56%    -22%
========================================
  Files           97      99      +2
  Lines         7817    7964    +147
========================================
- Hits          6167    4499   -1668
- Misses        1650    3465   +1815
```
What does this PR do?
Adds an ONNX → TFLite conversion pipeline for RF-DETR detection and segmentation models via onnx2tf. Users can now export any RF-DETR model to TFLite format with a single call:
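A hedged sketch of what that single call looks like, based on the parameters this PR describes — the class name, weights handling, and paths are placeholders, not verified against the merged API:

```python
# Hypothetical usage based on this PR's description; defaults may differ
# in the merged code.
def export_to_tflite_example():
    from rfdetr import RFDETRBase  # assumed entry point

    model = RFDETRBase()
    model.export(
        format="tflite",
        quantization="int8",               # None, "fp32", "fp16", or "int8"
        calibration_data="calib_images/",  # dir of images, .npy file, or array
        max_images=100,                    # calibration sample count
        output_dir="exports/",
    )
```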
Key changes
New module:

- `src/rfdetr/export/_tflite/converter.py`: core `export_tflite()` function — converts ONNX to TFLite via the onnx2tf Python API
  - `quantization="int8"` with calibration from an image directory, a `.npy` file, or a NumPy array
  - `max_images` parameter controls the calibration sample count (default: 100)
  - `output_signaturedefs=True` fix for segmentation models (onnx2tf node-naming issue: leading `/` characters violate the TF saved_model naming pattern)
  - Patches onnx2tf's `download_test_image_data()` to use local calibration data instead of downloading from GitHub (which fails in CI, behind firewalls, and on air-gapped systems)
  - `allow_pickle` compatibility shim for onnx2tf 1.x releases

Modified files

- `src/rfdetr/detr.py`: added `format="tflite"` routing in `export()` with `quantization`, `calibration_data`, `max_images`, and `output_dir` parameters
- `src/rfdetr/export/main.py` / `__init__.py`: wired up the `export_tflite` import
- `pyproject.toml`: added `onnx2tf>=1.26.0` and `tf-keras>=2.16.0` to the `[onnx]` optional dependencies
- `docs/learn/export.md`: TFLite usage guide, calibration explanation, parameter reference, and a note that FP32/FP16 models are always produced

Related Issue(s): Related to #173 and this PR
Type of Change
Testing
Test details:
56 unit tests in `tests/export/test_tflite_export.py` (all CPU, no GPU required):

- `export(format="tflite")` correctly dispatches to `export_tflite`
- Accepts `None`, `"fp32"`, `"fp16"`, `"int8"` quantization modes; rejects invalid values
- `.npy` file loading, NumPy array input, and directory-based image loading with resize
- `max_images` parameter forwarding: verifies `max_images` flows through `detr.py` → `export_tflite()` → `_prepare_calibration_data()` → `_load_calibration_images()`

Checklist