diff --git a/doc_page_extractor/check_env.py b/doc_page_extractor/check_env.py index db3cd12..7654b3f 100644 --- a/doc_page_extractor/check_env.py +++ b/doc_page_extractor/check_env.py @@ -1,6 +1,5 @@ import warnings -import torch _env_checked = False @@ -11,6 +10,16 @@ def check_env() -> None: return _env_checked = True + try: + import torch + except ImportError: + warnings.warn( + "This package requires PyTorch to run. Install it with: pip install torch torchvision", + RuntimeWarning, + stacklevel=2, + ) + raise + if torch.cuda.is_available(): return warnings.warn( diff --git a/doc_page_extractor/extraction_context.py b/doc_page_extractor/extraction_context.py index 73282c7..1f413bd 100644 --- a/doc_page_extractor/extraction_context.py +++ b/doc_page_extractor/extraction_context.py @@ -1,13 +1,9 @@ from typing import Any, Callable, cast - -import torch from transformers import StoppingCriteria from .types import ExtractionContext - - class ExtractionAbortedError(Exception): def __init__(self): super().__init__("Extraction was aborted.") @@ -59,7 +55,7 @@ def notify_finished(self): self._error.output_tokens = self._raw_context.output_tokens raise self._error - def __call__(self, input_ids, scores, **kwargs) -> torch.BoolTensor: + def __call__(self, input_ids, scores, **kwargs) -> Any: if self._error: return cast(Any, True) diff --git a/doc_page_extractor/extractor.py b/doc_page_extractor/extractor.py index 1cbeed7..4bd6849 100644 --- a/doc_page_extractor/extractor.py +++ b/doc_page_extractor/extractor.py @@ -1,11 +1,10 @@ import tempfile + from os import PathLike from pathlib import Path from typing import cast, Generator, Iterable - from PIL import Image -from .check_env import check_env from .model import DeepSeekOCRHugginfaceModel from .parser import ParsedItemKind, parse_ocr_response from .redacter import background_color, redact @@ -13,7 +12,6 @@ from .types import Layout, PageExtractor, ExtractionContext, DeepSeekOCRModel, DeepSeekOCRSize - def create_page_extractor( model_path: PathLike | str | None = None, local_only: bool = False, @@ -26,6 +24,7 @@ def create_page_extractor( ) return _PageExtractorImpls(model) + def create_page_extractor_with_model(model: DeepSeekOCRModel) -> PageExtractor: if not isinstance(model, DeepSeekOCRModel): raise TypeError("model must implement DeepSeekOCRModel protocol") @@ -50,8 +49,6 @@ def extract( context: ExtractionContext | None = None, device_number: int | None = None, ) -> Generator[LazyGetter[tuple[Image.Image, list[Layout]]], None, None]: - - check_env() assert stages >= 1, "stages must be at least 1" image_path = Path(image_path) @@ -76,7 +73,8 @@ def extract( device_number=device_number, ) extraction_pair = lazy_load( - load=lambda ip=image_path, res=response: self._generate_extraction_pair(ip, res), + load=lambda ip=image_path, res=response: self._generate_extraction_pair( + ip, res), ) yield extraction_pair @@ -93,7 +91,6 @@ def extract( if temp_dir is not None: temp_dir.cleanup() - def _generate_extraction_pair(self, image_path: Path, response: str) -> tuple[Image.Image, list[Layout]]: layouts: list[Layout] = [] image = Image.open(image_path) @@ -101,7 +98,6 @@ def _generate_extraction_pair(self, image_path: Path, response: str) -> tuple[Im layouts.append(Layout(ref, det, text)) return image, layouts - def _parse_response( self, image: Image.Image, response: str ) -> Generator[tuple[str, tuple[int, int, int, int], str | None], None, None]: diff --git a/doc_page_extractor/model.py b/doc_page_extractor/model.py index 2703518..cf5099f 100644 --- a/doc_page_extractor/model.py +++ b/doc_page_extractor/model.py @@ -3,14 +3,14 @@ from pathlib import Path from typing import Iterable -import torch from huggingface_hub import snapshot_download from readerwriterlock import rwlock from transformers import AutoModel, AutoTokenizer +from .types import DeepSeekOCRSize +from .check_env import check_env from .extraction_context import ExtractionContext from .injection import InferWithInterruption, preprocess_model -from .types import DeepSeekOCRSize @dataclass @@ -57,9 +57,8 @@ def __init__( self._model_path: Path | None = model_path self._local_only = local_only self._models: _Models | None = None - self._device_number_to_index: list[int | None] = self._create_device_number_to_index( - enable_devices_numbers=enable_devices_numbers, - ) + self._enable_devices_numbers: Iterable[int] | None = enable_devices_numbers + self._device_number_to_index: list[int | None] | None = None def download(self, revision: str | None) -> None: with self._rwlock.gen_wlock(): @@ -98,9 +97,9 @@ def generate( models = self._ensure_models() if device_number is None: - model_index = self._device_number_to_index[0] + model_index = self._get_device_number_to_index()[0] else: - model_index = self._device_number_to_index[device_number] + model_index = self._get_device_number_to_index()[device_number] if model_index is None: raise ValueError(f"Device number {device_number} is not enabled.") @@ -125,31 +124,10 @@ def generate( ) return text_result - def _create_device_number_to_index(self, enable_devices_numbers: Iterable[int] | None) -> list[int | None]: - if not torch.cuda.is_available(): - return [] - - device_count = torch.cuda.device_count() - if enable_devices_numbers is None: - return list(range(device_count)) - - device_number_to_index: list[int | None] = [None] * device_count - next_model_index: int = 0 - for enable_device_number in sorted(list(set(enable_devices_numbers))): - if enable_device_number < 0 or enable_device_number >= device_count: - raise ValueError( - f"Invalid device number {enable_device_number}, " - f"your system has {device_count} CUDA devices." - ) - device_number_to_index[enable_device_number] = next_model_index - next_model_index += 1 - - if next_model_index == 0: - raise ValueError("No devices are enabled for model loading.") - - return device_number_to_index - def _ensure_models(self) -> _Models: + check_env() + import torch + with self._rwlock.gen_rlock(): if self._models is not None: return self._models @@ -159,7 +137,8 @@ def _ensure_models(self) -> _Models: if self._models is not None: return self._models - if len(self._device_number_to_index) == 0: + device_number_to_index = self._get_device_number_to_index() + if len(device_number_to_index) == 0: raise RuntimeError("No CUDA devices available") name_or_path = self._model_name @@ -184,7 +163,7 @@ def _ensure_models(self) -> _Models: local_files_only=self._local_only, ) llm_models: list[AutoModel] = [] - for device_number, model_index in enumerate(self._device_number_to_index): + for device_number, model_index in enumerate(device_number_to_index): if model_index is None: continue model = AutoModel.from_pretrained( @@ -231,3 +210,34 @@ def _find_pretrained_path(self) -> str | None: return None latest_snapshot = max(snapshot_dirs, key=lambda d: d.stat().st_mtime) return str(latest_snapshot) + + def _get_device_number_to_index(self) -> list[int | None]: + if self._device_number_to_index is None: + import torch + if torch.cuda.is_available(): + device_count = torch.cuda.device_count() + if self._enable_devices_numbers is None: + self._device_number_to_index = list(range(device_count)) + else: + next_model_index: int = 0 + device_number_to_index: list[int | None] = [ + None] * device_count + for enable_device_number in sorted(list(set(self._enable_devices_numbers))): + if enable_device_number < 0 or enable_device_number >= device_count: + raise ValueError( + f"Invalid device number {enable_device_number}, " + f"your system has {device_count} CUDA devices." + ) + device_number_to_index[enable_device_number] = next_model_index + next_model_index += 1 + + if next_model_index == 0: + raise ValueError( + "No devices are enabled for model loading.") + self._device_number_to_index = device_number_to_index + else: + self._device_number_to_index = [] + + self._enable_devices_numbers = None + + return self._device_number_to_index