Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 10 additions & 1 deletion doc_page_extractor/check_env.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import warnings

import torch

_env_checked = False

Expand All @@ -11,6 +10,16 @@ def check_env() -> None:
return
_env_checked = True

try:
import torch
except ImportError:
warnings.warn(
"This package requires PyTorch to run. Install it with: pip install torch torchvision",
RuntimeWarning,
stacklevel=2,
)
raise

if torch.cuda.is_available():
return
warnings.warn(
Expand Down
6 changes: 1 addition & 5 deletions doc_page_extractor/extraction_context.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,9 @@
from typing import Any, Callable, cast

import torch
from transformers import StoppingCriteria

from .types import ExtractionContext




class ExtractionAbortedError(Exception):
def __init__(self):
super().__init__("Extraction was aborted.")
Expand Down Expand Up @@ -59,7 +55,7 @@ def notify_finished(self):
self._error.output_tokens = self._raw_context.output_tokens
raise self._error

def __call__(self, input_ids, scores, **kwargs) -> torch.BoolTensor:
def __call__(self, input_ids, scores, **kwargs) -> Any:
if self._error:
return cast(Any, True)

Expand Down
12 changes: 4 additions & 8 deletions doc_page_extractor/extractor.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,17 @@
import tempfile

from os import PathLike
from pathlib import Path
from typing import cast, Generator, Iterable

from PIL import Image

from .check_env import check_env
from .model import DeepSeekOCRHugginfaceModel
from .parser import ParsedItemKind, parse_ocr_response
from .redacter import background_color, redact
from .lazy_loader import lazy_load, LazyGetter
from .types import Layout, PageExtractor, ExtractionContext, DeepSeekOCRModel, DeepSeekOCRSize



def create_page_extractor(
model_path: PathLike | str | None = None,
local_only: bool = False,
Expand All @@ -26,6 +24,7 @@ def create_page_extractor(
)
return _PageExtractorImpls(model)


def create_page_extractor_with_model(model: DeepSeekOCRModel) -> PageExtractor:
if not isinstance(model, DeepSeekOCRModel):
raise TypeError("model must implement DeepSeekOCRModel protocol")
Expand All @@ -50,8 +49,6 @@ def extract(
context: ExtractionContext | None = None,
device_number: int | None = None,
) -> Generator[LazyGetter[tuple[Image.Image, list[Layout]]], None, None]:

check_env()
assert stages >= 1, "stages must be at least 1"

image_path = Path(image_path)
Expand All @@ -76,7 +73,8 @@ def extract(
device_number=device_number,
)
extraction_pair = lazy_load(
load=lambda ip=image_path, res=response: self._generate_extraction_pair(ip, res),
load=lambda ip=image_path, res=response: self._generate_extraction_pair(
ip, res),
)
yield extraction_pair

Expand All @@ -93,15 +91,13 @@ def extract(
if temp_dir is not None:
temp_dir.cleanup()


def _generate_extraction_pair(self, image_path: Path, response: str) -> tuple[Image.Image, list[Layout]]:
    """Open the page image and build the layouts described by an OCR response.

    Args:
        image_path: Path of the image file to open.
        response: Raw response text, handed to ``self._parse_response``.

    Returns:
        A ``(image, layouts)`` tuple: the opened image and one ``Layout``
        per ``(ref, det, text)`` triple produced by ``self._parse_response``.
    """
    page_image = Image.open(image_path)
    parsed_layouts: list[Layout] = [
        Layout(ref, det, text)
        for ref, det, text in self._parse_response(page_image, response)
    ]
    return page_image, parsed_layouts


def _parse_response(
self, image: Image.Image, response: str
) -> Generator[tuple[str, tuple[int, int, int, int], str | None], None, None]:
Expand Down
76 changes: 43 additions & 33 deletions doc_page_extractor/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,14 @@
from pathlib import Path
from typing import Iterable

import torch
from huggingface_hub import snapshot_download
from readerwriterlock import rwlock
from transformers import AutoModel, AutoTokenizer

from .types import DeepSeekOCRSize
from .check_env import check_env
from .extraction_context import ExtractionContext
from .injection import InferWithInterruption, preprocess_model
from .types import DeepSeekOCRSize


@dataclass
Expand Down Expand Up @@ -57,9 +57,8 @@ def __init__(
self._model_path: Path | None = model_path
self._local_only = local_only
self._models: _Models | None = None
self._device_number_to_index: list[int | None] = self._create_device_number_to_index(
enable_devices_numbers=enable_devices_numbers,
)
self._enable_devices_numbers: Iterable[int] | None = enable_devices_numbers
self._device_number_to_index: list[int | None] | None = None

def download(self, revision: str | None) -> None:
with self._rwlock.gen_wlock():
Expand Down Expand Up @@ -98,9 +97,9 @@ def generate(

models = self._ensure_models()
if device_number is None:
model_index = self._device_number_to_index[0]
model_index = self._get_device_number_to_index()[0]
else:
model_index = self._device_number_to_index[device_number]
model_index = self._get_device_number_to_index()[device_number]

if model_index is None:
raise ValueError(f"Device number {device_number} is not enabled.")
Expand All @@ -125,31 +124,10 @@ def generate(
)
return text_result

def _create_device_number_to_index(self, enable_devices_numbers: Iterable[int] | None) -> list[int | None]:
if not torch.cuda.is_available():
return []

device_count = torch.cuda.device_count()
if enable_devices_numbers is None:
return list(range(device_count))

device_number_to_index: list[int | None] = [None] * device_count
next_model_index: int = 0
for enable_device_number in sorted(list(set(enable_devices_numbers))):
if enable_device_number < 0 or enable_device_number >= device_count:
raise ValueError(
f"Invalid device number {enable_device_number}, "
f"your system has {device_count} CUDA devices."
)
device_number_to_index[enable_device_number] = next_model_index
next_model_index += 1

if next_model_index == 0:
raise ValueError("No devices are enabled for model loading.")

return device_number_to_index

def _ensure_models(self) -> _Models:
check_env()
import torch

with self._rwlock.gen_rlock():
if self._models is not None:
return self._models
Expand All @@ -159,7 +137,8 @@ def _ensure_models(self) -> _Models:
if self._models is not None:
return self._models

if len(self._device_number_to_index) == 0:
device_number_to_index = self._get_device_number_to_index()
if len(device_number_to_index) == 0:
raise RuntimeError("No CUDA devices available")

name_or_path = self._model_name
Expand All @@ -184,7 +163,7 @@ def _ensure_models(self) -> _Models:
local_files_only=self._local_only,
)
llm_models: list[AutoModel] = []
for device_number, model_index in enumerate(self._device_number_to_index):
for device_number, model_index in enumerate(device_number_to_index):
if model_index is None:
continue
model = AutoModel.from_pretrained(
Expand Down Expand Up @@ -231,3 +210,34 @@ def _find_pretrained_path(self) -> str | None:
return None
latest_snapshot = max(snapshot_dirs, key=lambda d: d.stat().st_mtime)
return str(latest_snapshot)

def _get_device_number_to_index(self) -> list[int | None]:
if self._device_number_to_index is None:
import torch
if torch.cuda.is_available():
device_count = torch.cuda.device_count()
if self._enable_devices_numbers is None:
self._device_number_to_index = list(range(device_count))
else:
next_model_index: int = 0
device_number_to_index: list[int | None] = [
None] * device_count
for enable_device_number in sorted(list(set(self._enable_devices_numbers))):
if enable_device_number < 0 or enable_device_number >= device_count:
raise ValueError(
f"Invalid device number {enable_device_number}, "
f"your system has {device_count} CUDA devices."
)
device_number_to_index[enable_device_number] = next_model_index
next_model_index += 1

if next_model_index == 0:
raise ValueError(
"No devices are enabled for model loading.")
self._device_number_to_index = device_number_to_index
else:
self._device_number_to_index = []

self._enable_devices_numbers = None

return self._device_number_to_index