Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion doc_page_extractor/extractor.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import tempfile
from os import PathLike
from pathlib import Path
from typing import Generator, cast
from typing import cast, Generator, Iterable

from PIL import Image

Expand All @@ -16,10 +16,12 @@
def create_page_extractor(
model_path: PathLike | None = None,
local_only: bool = False,
enable_devices_numbers: Iterable[int] | None = None,
) -> PageExtractor:
model: DeepSeekOCRHugginfaceModel = DeepSeekOCRHugginfaceModel(
model_path=Path(model_path) if model_path else None,
local_only=local_only,
enable_devices_numbers=enable_devices_numbers,
)
return _PageExtractorImpls(model)

Expand All @@ -45,6 +47,7 @@ def extract(
size: DeepSeekOCRSize,
stages: int = 1,
context: ExtractionContext | None = None,
device_number: int | None = None,
) -> Generator[tuple[Image.Image, list[Layout]], None, None]:
check_env()
assert stages >= 1, "stages must be at least 1"
Expand All @@ -57,6 +60,7 @@ def extract(
temp_path=temp_path,
size=size,
context=context,
device_number=device_number,
)
layouts: list[Layout] = []
for ref, det, text in self._parse_response(image, response):
Expand Down
1 change: 1 addition & 0 deletions doc_page_extractor/injection.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ def thread_safe_generate(*args, **kwargs):
return original_generate(*args, **kwargs)

model.generate = thread_safe_generate
return model


class InferWithInterruption:
Expand Down
87 changes: 68 additions & 19 deletions doc_page_extractor/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from dataclasses import dataclass
from importlib.util import find_spec
from pathlib import Path
from typing import Any
from typing import Iterable

import torch
from huggingface_hub import snapshot_download
Expand Down Expand Up @@ -37,11 +37,20 @@ class _SizeConfig:
else:
_ATTN_IMPLEMENTATION = "eager"

_Models = tuple[Any, Any]

@dataclass
class _Models:
tokenizer: AutoTokenizer
llms: list[AutoModel]


class DeepSeekOCRHugginfaceModel:
def __init__(self, model_path: Path | None, local_only: bool) -> None:
def __init__(
self,
model_path: Path | None,
local_only: bool,
enable_devices_numbers: Iterable[int] | None,
) -> None:
if local_only and model_path is None:
raise ValueError("model_path must be provided when local_only is True")

Expand All @@ -50,6 +59,28 @@ def __init__(self, model_path: Path | None, local_only: bool) -> None:
self._model_path: Path | None = model_path
self._local_only = local_only
self._models: _Models | None = None
self._device_number_to_index: list[int | None]

device_count = torch.cuda.device_count()
if device_count == 0:
raise RuntimeError("No CUDA devices available")

if enable_devices_numbers is None:
self._device_number_to_index = list(range(device_count))
else:
self._device_number_to_index = [None] * device_count
next_model_index: int = 0
for enable_device_number in sorted(list(set(enable_devices_numbers))):
if enable_device_number < 0 or enable_device_number >= device_count:
raise ValueError(
f"Invalid device number {enable_device_number}, "
f"your system has {device_count} CUDA devices."
)
self._device_number_to_index[enable_device_number] = next_model_index
next_model_index += 1

if next_model_index == 0:
raise ValueError("No devices are enabled for model loading.")

def download(self) -> None:
with self._rwlock.gen_wlock():
Expand Down Expand Up @@ -81,13 +112,26 @@ def generate(
temp_path: str,
size: DeepSeekOCRSize,
context: ExtractionContext | None,
device_number: int | None,
) -> str:
tokenizer, model = self._ensure_models()
if device_number is None:
model_index = self._device_number_to_index[0]
else:
model_index = self._device_number_to_index[device_number]

if model_index is None:
raise ValueError(f"Device number {device_number} is not enabled.")

models = self._ensure_models()
tokenizer = models.tokenizer
llm_model = models.llms[model_index]
config = _SIZE_CONFIGS[size]

with self._rwlock.gen_rlock():
config = _SIZE_CONFIGS[size]
# TODO: 支持直接读取图片地址
temp_image_path = os.path.join(temp_path, "temp_image.png")
image.save(temp_image_path)
with InferWithInterruption(model, context) as infer:
with InferWithInterruption(llm_model, context) as infer:
text_result = infer(
tokenizer,
prompt=prompt,
Expand Down Expand Up @@ -133,20 +177,25 @@ def _ensure_models(self) -> _Models:
cache_dir=cache_dir,
local_files_only=self._local_only,
)
model = AutoModel.from_pretrained(
pretrained_model_name_or_path=name_or_path,
_attn_implementation=_ATTN_IMPLEMENTATION,
trust_remote_code=True,
use_safetensors=True,
cache_dir=cache_dir,
local_files_only=self._local_only,
torch_dtype=torch.bfloat16,
device_map="cuda",
)
model = model.eval()
self._models = (tokenizer, model)
preprocess_model(model)
llm_models: list[AutoModel] = []
for device_number, model_index in enumerate(self._device_number_to_index):
if model_index is None:
continue
model = AutoModel.from_pretrained(
pretrained_model_name_or_path=name_or_path,
_attn_implementation=_ATTN_IMPLEMENTATION,
trust_remote_code=True,
use_safetensors=True,
cache_dir=cache_dir,
local_files_only=self._local_only,
)
model = model.cuda(device_number).to(torch.bfloat16)
llm_models.append(preprocess_model(model))

self._models = _Models(
tokenizer=tokenizer,
llms=llm_models,
)
return self._models

def _cache_dir(self) -> str | None:
Expand Down
2 changes: 2 additions & 0 deletions doc_page_extractor/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ def extract(
size: DeepSeekOCRSize,
stages: int = 1,
context: ExtractionContext | None = None,
device_number: int | None = None,
) -> Generator[tuple[Image.Image, list[Layout]], None, None]:
...

Expand All @@ -58,5 +59,6 @@ def generate(
temp_path: str,
size: DeepSeekOCRSize,
context: ExtractionContext | None,
device_number: int | None,
) -> str:
...
11 changes: 8 additions & 3 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,18 +16,23 @@ def main() -> None:
model_path=project_root / "models-cache",
local_only=False,
)
begin_at = time.time()
extractor.load_models()
print(f"Models loaded in {time.time() - begin_at:.2f} seconds.")

plot_dir = project_root / "plot"
plot_dir.mkdir(exist_ok=True)
name_stem = Path(image_name).stem
name_suffix = Path(image_name).suffix
created_at = time.time()
begin_at = time.time()

def check_aborted() -> bool:
if time.time() - created_at > _ABORT_TIMEOUT:
if time.time() - begin_at > _ABORT_TIMEOUT:
print("Aborted extraction due to timeout.")
return True
return False

print("Starting extraction...")
for i, (image, layouts) in enumerate(
extractor.extract(
image=Image.open(image_dir_path / image_name),
Expand All @@ -43,7 +48,7 @@ def check_aborted() -> bool:
output_path = plot_dir / f"{name_stem}_{i}{name_suffix}"
image.save(output_path)

print(f"Extraction cost {time.time() - created_at:.2f} seconds.")
print(f"Extraction cost {time.time() - begin_at:.2f} seconds.")

if __name__ == "__main__":
main()