diff --git a/checkpoint/orbax/checkpoint/experimental/v1/_src/synchronization/thread_utils.py b/checkpoint/orbax/checkpoint/experimental/v1/_src/synchronization/thread_utils.py index 33a9c10dc..737f90149 100644 --- a/checkpoint/orbax/checkpoint/experimental/v1/_src/synchronization/thread_utils.py +++ b/checkpoint/orbax/checkpoint/experimental/v1/_src/synchronization/thread_utils.py @@ -14,9 +14,9 @@ """Utilities for working with threads and asyncio event loops.""" -import asyncio -import threading -from typing import Awaitable, Callable, Generic, TypeVar +from typing import Any, Callable, Coroutine, Generic, TypeVar + +from orbax.checkpoint._src import asyncio_utils T = TypeVar('T') @@ -25,35 +25,22 @@ class BackgroundThreadRunner(Generic[T]): """A runner for an asyncio event loop in a background thread. This class expects an awaitable function that will be run in a background - thread. It creates an event loop that is passed to the thread. This event loop - should only be interacted with via asyncio thread-safe APIs, when tasks are - scheduled from the main thread. + thread and in a dedicated event loop, which are managed by an AsyncRunner + instance. """ def __init__( self, - target: Awaitable[T], + target: Coroutine[Any, Any, T], ): - self._target = target - self._event_loop = asyncio.new_event_loop() - self._thread = threading.Thread( - target=self._event_loop_runner, args=(self._event_loop,), daemon=True - ) - self._thread.start() - self._future = asyncio.run_coroutine_threadsafe( - self._target, self._event_loop - ) - - def _event_loop_runner(self, event_loop: asyncio.AbstractEventLoop): - event_loop.run_forever() - event_loop.close() + self._runner = asyncio_utils.AsyncRunner() + self._future = self._runner.run_coroutine(target) def result(self, timeout: float | None = None) -> T: r = self._future.result(timeout=timeout) - if self._thread: - self._event_loop.call_soon_threadsafe(self._event_loop.stop) - self._thread.join(timeout=timeout) - self._thread = None + if self._runner: + self._runner.shutdown() + self._runner = None return r def on_complete(self, callback: Callable[[T], None]) -> None: