Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,9 @@

"""Utilities for working with threads and asyncio event loops."""

import asyncio
import threading
from typing import Awaitable, Callable, Generic, TypeVar
from typing import Any, Callable, Coroutine, Generic, TypeVar

from orbax.checkpoint._src import asyncio_utils

T = TypeVar('T')

Expand All @@ -25,35 +25,22 @@ class BackgroundThreadRunner(Generic[T]):
"""A runner for an asyncio event loop in a background thread.

This class expects an awaitable function that will be run in a background
thread. It creates an event loop that is passed to the thread. This event loop
should only be interacted with via asyncio thread-safe APIs, when tasks are
scheduled from the main thread.
thread and in a dedicated event loop, which are managed by an AsyncRunner
instance.
"""

def __init__(
self,
target: Awaitable[T],
target: Coroutine[Any, Any, T],
):
self._target = target
self._event_loop = asyncio.new_event_loop()
self._thread = threading.Thread(
target=self._event_loop_runner, args=(self._event_loop,), daemon=True
)
self._thread.start()
self._future = asyncio.run_coroutine_threadsafe(
self._target, self._event_loop
)

def _event_loop_runner(self, event_loop: asyncio.AbstractEventLoop):
event_loop.run_forever()
event_loop.close()
self._runner = asyncio_utils.AsyncRunner()
self._future = self._runner.run_coroutine(target)

def result(self, timeout: float | None = None) -> T:
r = self._future.result(timeout=timeout)
if self._thread:
self._event_loop.call_soon_threadsafe(self._event_loop.stop)
self._thread.join(timeout=timeout)
self._thread = None
if self._runner:
self._runner.shutdown()
self._runner = None
return r

def on_complete(self, callback: Callable[[T], None]) -> None:
Expand Down
Loading