diff --git a/doc_page_extractor/extractor.py b/doc_page_extractor/extractor.py index aa678b1..1588b48 100644 --- a/doc_page_extractor/extractor.py +++ b/doc_page_extractor/extractor.py @@ -1,7 +1,7 @@ import tempfile from os import PathLike from pathlib import Path -from typing import Generator, cast +from typing import cast, Generator, Iterable from PIL import Image @@ -16,10 +16,12 @@ def create_page_extractor( model_path: PathLike | None = None, local_only: bool = False, + enable_devices_numbers: Iterable[int] | None = None, ) -> PageExtractor: model: DeepSeekOCRHugginfaceModel = DeepSeekOCRHugginfaceModel( model_path=Path(model_path) if model_path else None, local_only=local_only, + enable_devices_numbers=enable_devices_numbers, ) return _PageExtractorImpls(model) @@ -45,6 +47,7 @@ def extract( size: DeepSeekOCRSize, stages: int = 1, context: ExtractionContext | None = None, + device_number: int | None = None, ) -> Generator[tuple[Image.Image, list[Layout]], None, None]: check_env() assert stages >= 1, "stages must be at least 1" @@ -57,6 +60,7 @@ def extract( temp_path=temp_path, size=size, context=context, + device_number=device_number, ) layouts: list[Layout] = [] for ref, det, text in self._parse_response(image, response): diff --git a/doc_page_extractor/injection.py b/doc_page_extractor/injection.py index f98ebc8..12ec02b 100644 --- a/doc_page_extractor/injection.py +++ b/doc_page_extractor/injection.py @@ -78,6 +78,7 @@ def thread_safe_generate(*args, **kwargs): return original_generate(*args, **kwargs) model.generate = thread_safe_generate + return model class InferWithInterruption: diff --git a/doc_page_extractor/model.py b/doc_page_extractor/model.py index 346db89..deb543d 100644 --- a/doc_page_extractor/model.py +++ b/doc_page_extractor/model.py @@ -2,7 +2,7 @@ from dataclasses import dataclass from importlib.util import find_spec from pathlib import Path -from typing import Any +from typing import Iterable import torch from huggingface_hub import snapshot_download @@ -37,11 +37,20 @@ class _SizeConfig: else: _ATTN_IMPLEMENTATION = "eager" -_Models = tuple[Any, Any] + +@dataclass +class _Models: + tokenizer: AutoTokenizer + llms: list[AutoModel] class DeepSeekOCRHugginfaceModel: - def __init__(self, model_path: Path | None, local_only: bool) -> None: + def __init__( + self, + model_path: Path | None, + local_only: bool, + enable_devices_numbers: Iterable[int] | None, + ) -> None: if local_only and model_path is None: raise ValueError("model_path must be provided when local_only is True") @@ -50,6 +59,28 @@ def __init__(self, model_path: Path | None, local_only: bool) -> None: self._model_path: Path | None = model_path self._local_only = local_only self._models: _Models | None = None + self._device_number_to_index: list[int | None] + + device_count = torch.cuda.device_count() + if device_count == 0: + raise RuntimeError("No CUDA devices available") + + if enable_devices_numbers is None: + self._device_number_to_index = list(range(device_count)) + else: + self._device_number_to_index = [None] * device_count + next_model_index: int = 0 + for enable_device_number in sorted(list(set(enable_devices_numbers))): + if enable_device_number < 0 or enable_device_number >= device_count: + raise ValueError( + f"Invalid device number {enable_device_number}, " + f"your system has {device_count} CUDA devices." + ) + self._device_number_to_index[enable_device_number] = next_model_index + next_model_index += 1 + + if next_model_index == 0: + raise ValueError("No devices are enabled for model loading.") def download(self) -> None: with self._rwlock.gen_wlock(): @@ -81,13 +112,26 @@ def generate( temp_path: str, size: DeepSeekOCRSize, context: ExtractionContext | None, + device_number: int | None, ) -> str: - tokenizer, model = self._ensure_models() + if device_number is None: + model_index = self._device_number_to_index[0] + else: + model_index = self._device_number_to_index[device_number] + + if model_index is None: + raise ValueError(f"Device number {device_number} is not enabled.") + + models = self._ensure_models() + tokenizer = models.tokenizer + llm_model = models.llms[model_index] + config = _SIZE_CONFIGS[size] + with self._rwlock.gen_rlock(): - config = _SIZE_CONFIGS[size] + # TODO: 支持直接读取图片地址 temp_image_path = os.path.join(temp_path, "temp_image.png") image.save(temp_image_path) - with InferWithInterruption(model, context) as infer: + with InferWithInterruption(llm_model, context) as infer: text_result = infer( tokenizer, prompt=prompt, @@ -133,20 +177,25 @@ def _ensure_models(self) -> _Models: cache_dir=cache_dir, local_files_only=self._local_only, ) - model = AutoModel.from_pretrained( - pretrained_model_name_or_path=name_or_path, - _attn_implementation=_ATTN_IMPLEMENTATION, - trust_remote_code=True, - use_safetensors=True, - cache_dir=cache_dir, - local_files_only=self._local_only, - torch_dtype=torch.bfloat16, - device_map="cuda", - ) - model = model.eval() - self._models = (tokenizer, model) - preprocess_model(model) + llm_models: list[AutoModel] = [] + for device_number, model_index in enumerate(self._device_number_to_index): + if model_index is None: + continue + model = AutoModel.from_pretrained( + pretrained_model_name_or_path=name_or_path, + _attn_implementation=_ATTN_IMPLEMENTATION, + trust_remote_code=True, + use_safetensors=True, + cache_dir=cache_dir, + local_files_only=self._local_only, + ) + model = model.cuda(device_number).to(torch.bfloat16) + llm_models.append(preprocess_model(model)) + self._models = _Models( + tokenizer=tokenizer, + llms=llm_models, + ) return self._models def _cache_dir(self) -> str | None: diff --git a/doc_page_extractor/types.py b/doc_page_extractor/types.py index 8bfd754..8fcdab6 100644 --- a/doc_page_extractor/types.py +++ b/doc_page_extractor/types.py @@ -36,6 +36,7 @@ def extract( size: DeepSeekOCRSize, stages: int = 1, context: ExtractionContext | None = None, + device_number: int | None = None, ) -> Generator[tuple[Image.Image, list[Layout]], None, None]: ... @@ -58,5 +59,6 @@ def generate( temp_path: str, size: DeepSeekOCRSize, context: ExtractionContext | None, + device_number: int | None, ) -> str: ... \ No newline at end of file diff --git a/main.py b/main.py index a7d98d4..96899fc 100644 --- a/main.py +++ b/main.py @@ -16,18 +16,23 @@ def main() -> None: model_path=project_root / "models-cache", local_only=False, ) + begin_at = time.time() + extractor.load_models() + print(f"Models loaded in {time.time() - begin_at:.2f} seconds.") + plot_dir = project_root / "plot" plot_dir.mkdir(exist_ok=True) name_stem = Path(image_name).stem name_suffix = Path(image_name).suffix - created_at = time.time() + begin_at = time.time() def check_aborted() -> bool: - if time.time() - created_at > _ABORT_TIMEOUT: + if time.time() - begin_at > _ABORT_TIMEOUT: print("Aborted extraction due to timeout.") return True return False + print("Starting extraction...") for i, (image, layouts) in enumerate( extractor.extract( image=Image.open(image_dir_path / image_name), @@ -43,7 +48,7 @@ def check_aborted() -> bool: output_path = plot_dir / f"{name_stem}_{i}{name_suffix}" image.save(output_path) - print(f"Extraction cost {time.time() - created_at:.2f} seconds.") + print(f"Extraction cost {time.time() - begin_at:.2f} seconds.") if __name__ == "__main__": main()