From c68f287c7a654e76ecc27e1c9b68d717e4434ac1 Mon Sep 17 00:00:00 2001 From: Tao Zeyu Date: Fri, 28 Nov 2025 12:13:14 +0800 Subject: [PATCH 1/3] feat: support multi-cores GPU --- doc_page_extractor/injection.py | 1 + doc_page_extractor/model.py | 79 ++++++++++++++++------------ doc_page_extractor/resource_locks.py | 55 +++++++++++++++++++ main.py | 11 ++-- 4 files changed, 110 insertions(+), 36 deletions(-) create mode 100644 doc_page_extractor/resource_locks.py diff --git a/doc_page_extractor/injection.py b/doc_page_extractor/injection.py index f98ebc8..12ec02b 100644 --- a/doc_page_extractor/injection.py +++ b/doc_page_extractor/injection.py @@ -78,6 +78,7 @@ def thread_safe_generate(*args, **kwargs): return original_generate(*args, **kwargs) model.generate = thread_safe_generate + return model class InferWithInterruption: diff --git a/doc_page_extractor/model.py b/doc_page_extractor/model.py index 346db89..a607caf 100644 --- a/doc_page_extractor/model.py +++ b/doc_page_extractor/model.py @@ -2,7 +2,6 @@ from dataclasses import dataclass from importlib.util import find_spec from pathlib import Path -from typing import Any import torch from huggingface_hub import snapshot_download @@ -12,6 +11,7 @@ from .extraction_context import ExtractionContext from .injection import InferWithInterruption, preprocess_model +from .resource_locks import ResourceLocks from .types import DeepSeekOCRSize @@ -37,7 +37,11 @@ class _SizeConfig: else: _ATTN_IMPLEMENTATION = "eager" -_Models = tuple[Any, Any] + +@dataclass +class _Models: + tokenizer: AutoTokenizer + llms: ResourceLocks[AutoModel] class DeepSeekOCRHugginfaceModel: @@ -82,25 +86,27 @@ def generate( size: DeepSeekOCRSize, context: ExtractionContext | None, ) -> str: - tokenizer, model = self._ensure_models() + models = self._ensure_models() + tokenizer = models.tokenizer with self._rwlock.gen_rlock(): - config = _SIZE_CONFIGS[size] - temp_image_path = os.path.join(temp_path, "temp_image.png") - image.save(temp_image_path) - with InferWithInterruption(model, context) as infer: - text_result = infer( - tokenizer, - prompt=prompt, - image_file=temp_image_path, - output_path=temp_path, - base_size=config.base_size, - image_size=config.image_size, - crop_mode=config.crop_mode, - save_results=True, - test_compress=True, - eval_mode=True, - ) - return text_result + with models.llms.access() as llm_model: + config = _SIZE_CONFIGS[size] + temp_image_path = os.path.join(temp_path, "temp_image.png") + image.save(temp_image_path) + with InferWithInterruption(llm_model, context) as infer: + text_result = infer( + tokenizer, + prompt=prompt, + image_file=temp_image_path, + output_path=temp_path, + base_size=config.base_size, + image_size=config.image_size, + crop_mode=config.crop_mode, + save_results=True, + test_compress=True, + eval_mode=True, + ) + return text_result def _ensure_models(self) -> _Models: with self._rwlock.gen_rlock(): @@ -133,20 +139,27 @@ def _ensure_models(self) -> _Models: cache_dir=cache_dir, local_files_only=self._local_only, ) - model = AutoModel.from_pretrained( - pretrained_model_name_or_path=name_or_path, - _attn_implementation=_ATTN_IMPLEMENTATION, - trust_remote_code=True, - use_safetensors=True, - cache_dir=cache_dir, - local_files_only=self._local_only, - torch_dtype=torch.bfloat16, - device_map="cuda", - ) - model = model.eval() - self._models = (tokenizer, model) - preprocess_model(model) + device_count = torch.cuda.device_count() + if device_count == 0: + raise RuntimeError("No CUDA devices available") + + llm_models: list[AutoModel] = [] + for i in range(0, device_count): + model = AutoModel.from_pretrained( + pretrained_model_name_or_path=name_or_path, + _attn_implementation=_ATTN_IMPLEMENTATION, + trust_remote_code=True, + use_safetensors=True, + cache_dir=cache_dir, + local_files_only=self._local_only, + ) + model = model.cuda(i).to(torch.bfloat16) + llm_models.append(preprocess_model(model)) + self._models = _Models( + tokenizer=tokenizer, + llms=ResourceLocks(llm_models), + ) return self._models def _cache_dir(self) -> str | None: diff --git a/doc_page_extractor/resource_locks.py b/doc_page_extractor/resource_locks.py new file mode 100644 index 0000000..353515c --- /dev/null +++ b/doc_page_extractor/resource_locks.py @@ -0,0 +1,55 @@ +import threading +from dataclasses import dataclass +from typing import Generic, TypeVar + +T = TypeVar("T") + + +@dataclass +class _Node(Generic[T]): + resource: T + lock: threading.Lock + + +class ResourceLock(Generic[T]): + def __init__(self, resource: T, lock: threading.Lock) -> None: + self._resource = resource + self._lock = lock + + @property + def resource(self) -> T: + return self._resource + + def __enter__(self) -> T: + self._lock.acquire() + return self._resource + + def __exit__(self, exc_type, exc_val, exc_tb): # noqa: ANN001 + self._lock.release() + + +class ResourceLocks(Generic[T]): + def __init__(self, resources: list[T]) -> None: + if not resources: + raise ValueError("resources must not be empty") + + self._nodes = [_Node(resource=r, lock=threading.Lock()) for r in resources] + self._next_index = 0 + self._index_lock = threading.Lock() + + def access(self) -> ResourceLock[T]: + # TODO: 这是个简单的轮询逻辑,无法做到先到先得(只能基本做到)有优化空间 + with self._index_lock: + start_index = self._next_index + self._next_index = (self._next_index + 1) % len(self._nodes) + + for offset in range(len(self._nodes)): + index = (start_index + offset) % len(self._nodes) + node = self._nodes[index] + + if node.lock.acquire(blocking=False): + return ResourceLock(node.resource, node.lock) + + node = self._nodes[start_index] + node.lock.acquire() + return ResourceLock(node.resource, node.lock) diff --git a/main.py b/main.py index a7d98d4..96899fc 100644 --- a/main.py +++ b/main.py @@ -16,18 +16,23 @@ def main() -> None: model_path=project_root / "models-cache", local_only=False, ) + begin_at = time.time() + extractor.load_models() + print(f"Models loaded in {time.time() - begin_at:.2f} seconds.") + plot_dir = project_root / "plot" plot_dir.mkdir(exist_ok=True) name_stem = Path(image_name).stem name_suffix = Path(image_name).suffix - created_at = time.time() + begin_at = time.time() def check_aborted() -> bool: - if time.time() - created_at > _ABORT_TIMEOUT: + if time.time() - begin_at > _ABORT_TIMEOUT: print("Aborted extraction due to timeout.") return True return False + print("Starting extraction...") for i, (image, layouts) in enumerate( extractor.extract( image=Image.open(image_dir_path / image_name), @@ -43,7 +48,7 @@ def check_aborted() -> bool: output_path = plot_dir / f"{name_stem}_{i}{name_suffix}" image.save(output_path) - print(f"Extraction cost {time.time() - created_at:.2f} seconds.") + print(f"Extraction cost {time.time() - begin_at:.2f} seconds.") if __name__ == "__main__": main() From 1b80d420a1c77b72149687ccf4671e6f06523913 Mon Sep 17 00:00:00 2001 From: Tao Zeyu Date: Fri, 28 Nov 2025 12:22:46 +0800 Subject: [PATCH 2/3] feat: don't use locks (user should implement it if they want) --- doc_page_extractor/model.py | 42 ++++++++++----------- doc_page_extractor/resource_locks.py | 55 ---------------------------- 2 files changed, 21 insertions(+), 76 deletions(-) delete mode 100644 doc_page_extractor/resource_locks.py diff --git a/doc_page_extractor/model.py b/doc_page_extractor/model.py index a607caf..225faa6 100644 --- a/doc_page_extractor/model.py +++ b/doc_page_extractor/model.py @@ -11,7 +11,6 @@ from .extraction_context import ExtractionContext from .injection import InferWithInterruption, preprocess_model -from .resource_locks import ResourceLocks from .types import DeepSeekOCRSize @@ -41,7 +40,7 @@ class _SizeConfig: @dataclass class _Models: tokenizer: AutoTokenizer - llms: ResourceLocks[AutoModel] + llms: list[AutoModel] class DeepSeekOCRHugginfaceModel: @@ -89,24 +88,25 @@ def generate( models = self._ensure_models() tokenizer = models.tokenizer with self._rwlock.gen_rlock(): - with models.llms.access() as llm_model: - config = _SIZE_CONFIGS[size] - temp_image_path = os.path.join(temp_path, "temp_image.png") - image.save(temp_image_path) - with InferWithInterruption(llm_model, context) as infer: - text_result = infer( - tokenizer, - prompt=prompt, - image_file=temp_image_path, - output_path=temp_path, - base_size=config.base_size, - image_size=config.image_size, - crop_mode=config.crop_mode, - save_results=True, - test_compress=True, - eval_mode=True, - ) - return text_result + llm_model = models.llms[0] + config = _SIZE_CONFIGS[size] + # TODO: 支持直接读取图片地址 + temp_image_path = os.path.join(temp_path, "temp_image.png") + image.save(temp_image_path) + with InferWithInterruption(llm_model, context) as infer: + text_result = infer( + tokenizer, + prompt=prompt, + image_file=temp_image_path, + output_path=temp_path, + base_size=config.base_size, + image_size=config.image_size, + crop_mode=config.crop_mode, + save_results=True, + test_compress=True, + eval_mode=True, + ) + return text_result def _ensure_models(self) -> _Models: with self._rwlock.gen_rlock(): @@ -158,7 +158,7 @@ def _ensure_models(self) -> _Models: self._models = _Models( tokenizer=tokenizer, - llms=ResourceLocks(llm_models), + llms=llm_models, ) return self._models diff --git a/doc_page_extractor/resource_locks.py b/doc_page_extractor/resource_locks.py deleted file mode 100644 index 353515c..0000000 --- a/doc_page_extractor/resource_locks.py +++ /dev/null @@ -1,55 +0,0 @@ -import threading -from dataclasses import dataclass -from typing import Generic, TypeVar - -T = TypeVar("T") - - -@dataclass -class _Node(Generic[T]): - resource: T - lock: threading.Lock - - -class ResourceLock(Generic[T]): - def __init__(self, resource: T, lock: threading.Lock) -> None: - self._resource = resource - self._lock = lock - - @property - def resource(self) -> T: - return self._resource - - def __enter__(self) -> T: - self._lock.acquire() - return self._resource - - def __exit__(self, exc_type, exc_val, exc_tb): # noqa: ANN001 - self._lock.release() - - -class ResourceLocks(Generic[T]): - def __init__(self, resources: list[T]) -> None: - if not resources: - raise ValueError("resources must not be empty") - - self._nodes = [_Node(resource=r, lock=threading.Lock()) for r in resources] - self._next_index = 0 - self._index_lock = threading.Lock() - - def access(self) -> ResourceLock[T]: - # TODO: 这是个简单的轮询逻辑,无法做到先到先得(只能基本做到)有优化空间 - with self._index_lock: - start_index = self._next_index - self._next_index = (self._next_index + 1) % len(self._nodes) - - for offset in range(len(self._nodes)): - index = (start_index + offset) % len(self._nodes) - node = self._nodes[index] - - if node.lock.acquire(blocking=False): - return ResourceLock(node.resource, node.lock) - - node = self._nodes[start_index] - node.lock.acquire() - return ResourceLock(node.resource, node.lock) From 804fb7e44b182e6374d2e87a0f51ac90b0999f6b Mon Sep 17 00:00:00 2001 From: Tao Zeyu Date: Fri, 28 Nov 2025 12:46:22 +0800 Subject: [PATCH 3/3] feat: support enable devices list --- doc_page_extractor/extractor.py | 6 +++- doc_page_extractor/model.py | 54 +++++++++++++++++++++++++++------ doc_page_extractor/types.py | 2 ++ 3 files changed, 52 insertions(+), 10 deletions(-) diff --git a/doc_page_extractor/extractor.py b/doc_page_extractor/extractor.py index aa678b1..1588b48 100644 --- a/doc_page_extractor/extractor.py +++ b/doc_page_extractor/extractor.py @@ -1,7 +1,7 @@ import tempfile from os import PathLike from pathlib import Path -from typing import Generator, cast +from typing import cast, Generator, Iterable from PIL import Image @@ -16,10 +16,12 @@ def create_page_extractor( model_path: PathLike | None = None, local_only: bool = False, + enable_devices_numbers: Iterable[int] | None = None, ) -> PageExtractor: model: DeepSeekOCRHugginfaceModel = DeepSeekOCRHugginfaceModel( model_path=Path(model_path) if model_path else None, local_only=local_only, + enable_devices_numbers=enable_devices_numbers, ) return _PageExtractorImpls(model) @@ -45,6 +47,7 @@ def extract( size: DeepSeekOCRSize, stages: int = 1, context: ExtractionContext | None = None, + device_number: int | None = None, ) -> Generator[tuple[Image.Image, list[Layout]], None, None]: check_env() assert stages >= 1, "stages must be at least 1" @@ -57,6 +60,7 @@ def extract( temp_path=temp_path, size=size, context=context, + device_number=device_number, ) layouts: list[Layout] = [] for ref, det, text in self._parse_response(image, response): diff --git a/doc_page_extractor/model.py b/doc_page_extractor/model.py index 225faa6..deb543d 100644 --- a/doc_page_extractor/model.py +++ b/doc_page_extractor/model.py @@ -2,6 +2,7 @@ from dataclasses import dataclass from importlib.util import find_spec from pathlib import Path +from typing import Iterable import torch from huggingface_hub import snapshot_download @@ -44,7 +45,12 @@ class _Models: class DeepSeekOCRHugginfaceModel: - def __init__(self, model_path: Path | None, local_only: bool) -> None: + def __init__( + self, + model_path: Path | None, + local_only: bool, + enable_devices_numbers: Iterable[int] | None, + ) -> None: if local_only and model_path is None: raise ValueError("model_path must be provided when local_only is True") @@ -53,6 +59,28 @@ def __init__(self, model_path: Path | None, local_only: bool) -> None: self._model_path: Path | None = model_path self._local_only = local_only self._models: _Models | None = None + self._device_number_to_index: list[int | None] + + device_count = torch.cuda.device_count() + if device_count == 0: + raise RuntimeError("No CUDA devices available") + + if enable_devices_numbers is None: + self._device_number_to_index = list(range(device_count)) + else: + self._device_number_to_index = [None] * device_count + next_model_index: int = 0 + for enable_device_number in sorted(list(set(enable_devices_numbers))): + if enable_device_number < 0 or enable_device_number >= device_count: + raise ValueError( + f"Invalid device number {enable_device_number}, " + f"your system has {device_count} CUDA devices." + ) + self._device_number_to_index[enable_device_number] = next_model_index + next_model_index += 1 + + if next_model_index == 0: + raise ValueError("No devices are enabled for model loading.") def download(self) -> None: with self._rwlock.gen_wlock(): @@ -84,12 +112,22 @@ def generate( temp_path: str, size: DeepSeekOCRSize, context: ExtractionContext | None, + device_number: int | None, ) -> str: + if device_number is None: + model_index = self._device_number_to_index[0] + else: + model_index = self._device_number_to_index[device_number] + + if model_index is None: + raise ValueError(f"Device number {device_number} is not enabled.") + models = self._ensure_models() tokenizer = models.tokenizer + llm_model = models.llms[model_index] + config = _SIZE_CONFIGS[size] + with self._rwlock.gen_rlock(): - llm_model = models.llms[0] - config = _SIZE_CONFIGS[size] # TODO: 支持直接读取图片地址 temp_image_path = os.path.join(temp_path, "temp_image.png") image.save(temp_image_path) @@ -139,12 +177,10 @@ def _ensure_models(self) -> _Models: cache_dir=cache_dir, local_files_only=self._local_only, ) - device_count = torch.cuda.device_count() - if device_count == 0: - raise RuntimeError("No CUDA devices available") - llm_models: list[AutoModel] = [] - for i in range(0, device_count): + for device_number, model_index in enumerate(self._device_number_to_index): + if model_index is None: + continue model = AutoModel.from_pretrained( pretrained_model_name_or_path=name_or_path, _attn_implementation=_ATTN_IMPLEMENTATION, @@ -153,7 +189,7 @@ def _ensure_models(self) -> _Models: cache_dir=cache_dir, local_files_only=self._local_only, ) - model = model.cuda(i).to(torch.bfloat16) + model = model.cuda(device_number).to(torch.bfloat16) llm_models.append(preprocess_model(model)) self._models = _Models( diff --git a/doc_page_extractor/types.py b/doc_page_extractor/types.py index 8bfd754..8fcdab6 100644 --- a/doc_page_extractor/types.py +++ b/doc_page_extractor/types.py @@ -36,6 +36,7 @@ def extract( size: DeepSeekOCRSize, stages: int = 1, context: ExtractionContext | None = None, + device_number: int | None = None, ) -> Generator[tuple[Image.Image, list[Layout]], None, None]: ... @@ -58,5 +59,6 @@ def generate( temp_path: str, size: DeepSeekOCRSize, context: ExtractionContext | None, + device_number: int | None, ) -> str: ... \ No newline at end of file