Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 68 additions & 0 deletions frigate/test/test_rknn_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,74 @@

from frigate.util.rknn_converter import is_rknn_compatible

import subprocess
from frigate.util.rknn_converter import ensure_rknn_toolkit

class TestEnsureRknnToolkit(unittest.TestCase):
def setUp(self):
# We need to save the original builtins.__import__ to call it for non-mocked modules
self.original_import = __import__
# Create a dict to track import attempts to test the two-pass logic
self.import_attempts = {"rknn": 0}

def _mock_import(self, name, *args, **kwargs):
if name == "rknn":
self.import_attempts["rknn"] += 1
if self.rknn_import_behavior[self.import_attempts["rknn"] - 1] == "fail":
raise ImportError(f"No module named {name}")
return MagicMock()
return self.original_import(name, *args, **kwargs)

@patch("builtins.__import__")
@patch("frigate.util.rknn_converter.subprocess.check_call")
def test_rknn_already_installed(self, mock_check_call, mock_import):
self.rknn_import_behavior = ["success"]
mock_import.side_effect = self._mock_import

result = ensure_rknn_toolkit()

self.assertTrue(result)
mock_check_call.assert_not_called()
self.assertEqual(self.import_attempts["rknn"], 1)

@patch("builtins.__import__")
@patch("frigate.util.rknn_converter.subprocess.check_call")
def test_rknn_dynamic_install_success(self, mock_check_call, mock_import):
self.rknn_import_behavior = ["fail", "success"]
mock_import.side_effect = self._mock_import

result = ensure_rknn_toolkit()

self.assertTrue(result)
mock_check_call.assert_called_once()
self.assertEqual(self.import_attempts["rknn"], 2)

@patch("builtins.__import__")
@patch("frigate.util.rknn_converter.subprocess.check_call")
def test_rknn_dynamic_install_fail_subprocess(self, mock_check_call, mock_import):
self.rknn_import_behavior = ["fail"]
mock_import.side_effect = self._mock_import
mock_check_call.side_effect = subprocess.CalledProcessError(1, "pip")

result = ensure_rknn_toolkit()

self.assertFalse(result)
mock_check_call.assert_called_once()
self.assertEqual(self.import_attempts["rknn"], 1)

@patch("builtins.__import__")
@patch("frigate.util.rknn_converter.subprocess.check_call")
def test_rknn_dynamic_install_fail_import(self, mock_check_call, mock_import):
self.rknn_import_behavior = ["fail", "fail"]
mock_import.side_effect = self._mock_import

result = ensure_rknn_toolkit()

self.assertFalse(result)
mock_check_call.assert_called_once()
self.assertEqual(self.import_attempts["rknn"], 2)


class TestIsRknnCompatible(unittest.TestCase):
@patch("frigate.util.rknn_converter.get_soc_type")
def test_no_soc(self, mock_get_soc_type):
Expand Down
33 changes: 27 additions & 6 deletions frigate/util/rknn_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,15 +128,36 @@ def ensure_torch_dependencies() -> bool:


def ensure_rknn_toolkit() -> bool:
"""Ensure RKNN toolkit is available."""
"""Dynamically install rknn-toolkit2 if not available."""
try:
from rknn.api import RKNN # type: ignore # noqa: F401
import rknn # type: ignore # noqa: F401

logger.debug("RKNN toolkit is already available")
logger.debug("RKNN Toolkit is already available")
return True
except ImportError as e:
logger.error(f"RKNN toolkit not found. Please ensure it's installed. {e}")
return False
except ImportError:
logger.info("RKNN Toolkit not found, attempting to install...")

try:
subprocess.check_call(
[
sys.executable,
"-m",
"pip",
"install",
"--break-system-packages",
"rknn-toolkit2",
],
stdout=subprocess.DEVNULL,
stderr=subprocess.DEVNULL,
)

import rknn # type: ignore # noqa: F401

logger.info("RKNN Toolkit installed successfully")
return True
except (subprocess.CalledProcessError, ImportError) as e:
logger.error(f"Failed to install RKNN Toolkit: {e}")
return False


def get_soc_type() -> Optional[str]:
Expand Down
Loading