From 2be10fc57bce3e0db51e9e9dbc1bb9b82b60a6f9 Mon Sep 17 00:00:00 2001 From: Tao Zeyu Date: Fri, 28 Nov 2025 11:40:14 +0800 Subject: [PATCH 1/3] refactor: migrate into types.py --- .vscode/settings.json | 3 +- doc_page_extractor/__init__.py | 9 ++++-- doc_page_extractor/extraction_context.py | 10 ++---- doc_page_extractor/extractor.py | 11 ++----- doc_page_extractor/injection.py | 3 +- doc_page_extractor/model.py | 4 +-- doc_page_extractor/plot.py | 2 +- doc_page_extractor/types.py | 39 ++++++++++++++++++++++++ 8 files changed, 56 insertions(+), 25 deletions(-) create mode 100644 doc_page_extractor/types.py diff --git a/.vscode/settings.json b/.vscode/settings.json index 6a03f57..82183c6 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -4,7 +4,8 @@ "editor.codeActionsOnSave": { "source.organizeImports": "explicit" }, - "files.trimTrailingWhitespace": true + "python.linting.enabled": true, + "files.trimTrailingWhitespace": true, }, "cSpell.words": [ "deepseek", diff --git a/doc_page_extractor/__init__.py b/doc_page_extractor/__init__.py index 8b97ba4..4fcd487 100644 --- a/doc_page_extractor/__init__.py +++ b/doc_page_extractor/__init__.py @@ -1,12 +1,15 @@ from .extraction_context import ( AbortError, ExtractionAbortedError, - ExtractionContext, TokenLimitError, ) -from .extractor import Layout, PageExtractor -from .model import DeepSeekOCRSize from .plot import plot +from .types import ( + Layout, + PageExtractor, + ExtractionContext, + DeepSeekOCRSize, +) __version__ = "1.0.0" __all__ = [ diff --git a/doc_page_extractor/extraction_context.py b/doc_page_extractor/extraction_context.py index 9c2aec6..73282c7 100644 --- a/doc_page_extractor/extraction_context.py +++ b/doc_page_extractor/extraction_context.py @@ -1,17 +1,11 @@ -from dataclasses import dataclass from typing import Any, Callable, cast import torch from transformers import StoppingCriteria +from .types import ExtractionContext + -@dataclass -class ExtractionContext: - check_aborted: Callable[[], bool] - max_tokens: int | None = None - max_output_tokens: int | None = None - input_tokens: int = 0 - output_tokens: int = 0 class ExtractionAbortedError(Exception): diff --git a/doc_page_extractor/extractor.py b/doc_page_extractor/extractor.py index 838e08a..50340e0 100644 --- a/doc_page_extractor/extractor.py +++ b/doc_page_extractor/extractor.py @@ -1,5 +1,4 @@ import tempfile -from dataclasses import dataclass from os import PathLike from pathlib import Path from typing import Generator, cast @@ -7,18 +6,12 @@ from PIL import Image from .check_env import check_env -from .extraction_context import ExtractionContext -from .model import DeepSeekOCRModel, DeepSeekOCRSize +from .model import DeepSeekOCRModel from .parser import ParsedItemKind, parse_ocr_response from .redacter import background_color, redact +from .types import Layout, ExtractionContext, DeepSeekOCRSize -@dataclass -class Layout: - ref: str - det: tuple[int, int, int, int] - text: str | None - class PageExtractor: def __init__( diff --git a/doc_page_extractor/injection.py b/doc_page_extractor/injection.py index ae8bd6d..f98ebc8 100644 --- a/doc_page_extractor/injection.py +++ b/doc_page_extractor/injection.py @@ -59,7 +59,8 @@ from transformers import StoppingCriteria -from .extraction_context import AbortStoppingCriteria, ExtractionContext +from .types import ExtractionContext +from .extraction_context import AbortStoppingCriteria _LOCAL = threading.local() _LOCAL_KEY = "value" diff --git a/doc_page_extractor/model.py b/doc_page_extractor/model.py index 15abe7c..1e2c29f 100644 --- a/doc_page_extractor/model.py +++ b/doc_page_extractor/model.py @@ -2,7 +2,7 @@ from dataclasses import dataclass from importlib.util import find_spec from pathlib import Path -from typing import Any, Literal +from typing import Any import torch from huggingface_hub import snapshot_download @@ -12,8 +12,8 @@ from .extraction_context import ExtractionContext from .injection import InferWithInterruption, preprocess_model +from .types import DeepSeekOCRSize -DeepSeekOCRSize = Literal["tiny", "small", "base", "large", "gundam"] @dataclass diff --git a/doc_page_extractor/plot.py b/doc_page_extractor/plot.py index 7cf3925..53fce5e 100644 --- a/doc_page_extractor/plot.py +++ b/doc_page_extractor/plot.py @@ -4,7 +4,7 @@ from PIL.Image import Image from PIL.ImageFont import FreeTypeFont, load_default -from .extractor import Layout +from .types import Layout _FRAGMENT_COLOR = (0x49, 0xCF, 0xCB) # Light Green _Color = tuple[int, int, int] diff --git a/doc_page_extractor/types.py b/doc_page_extractor/types.py new file mode 100644 index 0000000..cab0187 --- /dev/null +++ b/doc_page_extractor/types.py @@ -0,0 +1,39 @@ +from dataclasses import dataclass +from typing import Protocol, runtime_checkable +from typing import Generator, Literal, Callable + +from PIL import Image + + +DeepSeekOCRSize = Literal["tiny", "small", "base", "large", "gundam"] + +@dataclass +class Layout: + ref: str + det: tuple[int, int, int, int] + text: str | None + +@dataclass +class ExtractionContext: + check_aborted: Callable[[], bool] + max_tokens: int | None = None + max_output_tokens: int | None = None + input_tokens: int = 0 + output_tokens: int = 0 + +@runtime_checkable +class PageExtractor(Protocol): + def download_models(self) -> None: + ... + + def load_models(self) -> None: + ... + + def extract( + self, + image: Image.Image, + size: DeepSeekOCRSize, + stages: int = 1, + context: ExtractionContext | None = None, + ) -> Generator[tuple[Image.Image, list[Layout]], None, None]: + ... \ No newline at end of file From 14eba8564c663aaaa6dc29820a72e2e10b0154e2 Mon Sep 17 00:00:00 2001 From: Tao Zeyu Date: Fri, 28 Nov 2025 11:50:36 +0800 Subject: [PATCH 2/3] refactor: enable pass model adapter --- doc_page_extractor/__init__.py | 8 ++++++-- doc_page_extractor/extractor.py | 33 +++++++++++++++++++++------------ doc_page_extractor/model.py | 2 +- doc_page_extractor/types.py | 23 +++++++++++++++++++++++ main.py | 4 ++-- 5 files changed, 53 insertions(+), 17 deletions(-) diff --git a/doc_page_extractor/__init__.py b/doc_page_extractor/__init__.py index 4fcd487..b90d7a9 100644 --- a/doc_page_extractor/__init__.py +++ b/doc_page_extractor/__init__.py @@ -3,22 +3,26 @@ ExtractionAbortedError, TokenLimitError, ) +from .extractor import create_page_extractor from .plot import plot from .types import ( Layout, PageExtractor, + DeepSeekOCRModel, ExtractionContext, DeepSeekOCRSize, ) __version__ = "1.0.0" __all__ = [ + "plot", + "create_page_extractor", + "PageExtractor", "DeepSeekOCRSize", + "DeepSeekOCRModel", "ExtractionContext", "AbortError", "ExtractionAbortedError", "TokenLimitError", "Layout", - "PageExtractor", - "plot", ] diff --git a/doc_page_extractor/extractor.py b/doc_page_extractor/extractor.py index 50340e0..aa678b1 100644 --- a/doc_page_extractor/extractor.py +++ b/doc_page_extractor/extractor.py @@ -6,23 +6,32 @@ from PIL import Image from .check_env import check_env -from .model import DeepSeekOCRModel +from .model import DeepSeekOCRHugginfaceModel from .parser import ParsedItemKind, parse_ocr_response from .redacter import background_color, redact -from .types import Layout, ExtractionContext, DeepSeekOCRSize +from .types import Layout, PageExtractor, ExtractionContext, DeepSeekOCRModel, DeepSeekOCRSize -class PageExtractor: - def __init__( - self, - model_path: PathLike | None = None, - local_only: bool = False, - ) -> None: - self._model: DeepSeekOCRModel = DeepSeekOCRModel( - model_path=Path(model_path) if model_path else None, - local_only=local_only, - ) +def create_page_extractor( + model_path: PathLike | None = None, + local_only: bool = False, +) -> PageExtractor: + model: DeepSeekOCRHugginfaceModel = DeepSeekOCRHugginfaceModel( + model_path=Path(model_path) if model_path else None, + local_only=local_only, + ) + return _PageExtractorImpls(model) + +def create_page_extractor_with_model(model: DeepSeekOCRModel) -> PageExtractor: + if not isinstance(model, DeepSeekOCRModel): + raise TypeError("model must implement DeepSeekOCRModel protocol") + return _PageExtractorImpls(model) + + +class _PageExtractorImpls: + def __init__(self, model: DeepSeekOCRModel) -> None: + self._model: DeepSeekOCRModel = model def download_models(self) -> None: self._model.download() diff --git a/doc_page_extractor/model.py b/doc_page_extractor/model.py index 1e2c29f..346db89 100644 --- a/doc_page_extractor/model.py +++ b/doc_page_extractor/model.py @@ -40,7 +40,7 @@ class _SizeConfig: _Models = tuple[Any, Any] -class DeepSeekOCRModel: +class DeepSeekOCRHugginfaceModel: def __init__(self, model_path: Path | None, local_only: bool) -> None: if local_only and model_path is None: raise ValueError("model_path must be provided when local_only is True") diff --git a/doc_page_extractor/types.py b/doc_page_extractor/types.py index cab0187..8bfd754 100644 --- a/doc_page_extractor/types.py +++ b/doc_page_extractor/types.py @@ -21,6 +21,7 @@ class ExtractionContext: input_tokens: int = 0 output_tokens: int = 0 + @runtime_checkable class PageExtractor(Protocol): def download_models(self) -> None: @@ -36,4 +37,26 @@ def extract( stages: int = 1, context: ExtractionContext | None = None, ) -> Generator[tuple[Image.Image, list[Layout]], None, None]: + ... + + +@runtime_checkable +class DeepSeekOCRModel(Protocol): + def download(self) -> None: + ... + + def load(self) -> None: + ... + + def unload(self) -> None: + ... + + def generate( + self, + image: Image.Image, + prompt: str, + temp_path: str, + size: DeepSeekOCRSize, + context: ExtractionContext | None, + ) -> str: ... \ No newline at end of file diff --git a/main.py b/main.py index 989be4d..a7d98d4 100644 --- a/main.py +++ b/main.py @@ -3,7 +3,7 @@ from PIL import Image -from doc_page_extractor import ExtractionContext, PageExtractor, plot +from doc_page_extractor import plot, create_page_extractor, ExtractionContext _ABORT_TIMEOUT = 9999.0 # seconds @@ -12,7 +12,7 @@ def main() -> None: project_root = Path(__file__).parent image_dir_path = project_root / "tests" / "images" image_name = "double_column.png" - extractor = PageExtractor( + extractor = create_page_extractor( model_path=project_root / "models-cache", local_only=False, ) From 741af9c3c35594830b4f21f32bba1dd52e2ba3ab Mon Sep 17 00:00:00 2001 From: Tao Zeyu Date: Fri, 28 Nov 2025 11:54:07 +0800 Subject: [PATCH 3/3] chore: remove useless source --- poetry.lock | 4 ++-- pyproject.toml | 5 ----- 2 files changed, 2 insertions(+), 7 deletions(-) diff --git a/poetry.lock b/poetry.lock index 7065f9e..c557173 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 2.1.3 and should not be changed by hand. +# This file is automatically @generated by Poetry 2.2.1 and should not be changed by hand. [[package]] name = "accelerate" @@ -1778,4 +1778,4 @@ zstd = ["zstandard (>=0.18.0)"] [metadata] lock-version = "2.1" python-versions = ">=3.10,<3.14" -content-hash = "4ddcd2f8c79cdae6fa29f7cadd12190319975207c909786777025771fc3b020f" +content-hash = "0e38dddce3b877284ec3949162a4b81d9f5ffb6ac7bd032c8f08286b2624a68f" diff --git a/pyproject.toml b/pyproject.toml index a1bc010..22cbd66 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -48,11 +48,6 @@ name = "pytorch-cu121" url = "https://download.pytorch.org/whl/cu121" priority = "explicit" -[[tool.poetry.source]] -name = "pytorch-cu118" -url = "https://download.pytorch.org/whl/cu118" -priority = "explicit" - [tool.poetry.group.dev.dependencies] pylint = "^3.3.7" # Development environment includes CUDA 12.1 PyTorch