Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion .vscode/settings.json
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@
"editor.codeActionsOnSave": {
"source.organizeImports": "explicit"
},
"files.trimTrailingWhitespace": true
"python.linting.enabled": true,
"files.trimTrailingWhitespace": true,
},
"cSpell.words": [
"deepseek",
Expand Down
17 changes: 12 additions & 5 deletions doc_page_extractor/__init__.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,28 @@
from .extraction_context import (
AbortError,
ExtractionAbortedError,
ExtractionContext,
TokenLimitError,
)
from .extractor import Layout, PageExtractor
from .model import DeepSeekOCRSize
from .extractor import create_page_extractor
from .plot import plot
from .types import (
Layout,
PageExtractor,
DeepSeekOCRModel,
ExtractionContext,
DeepSeekOCRSize,
)

__version__ = "1.0.0"
__all__ = [
"plot",
"create_page_extractor",
"PageExtractor",
"DeepSeekOCRSize",
"DeepSeekOCRModel",
"ExtractionContext",
"AbortError",
"ExtractionAbortedError",
"TokenLimitError",
"Layout",
"PageExtractor",
"plot",
]
10 changes: 2 additions & 8 deletions doc_page_extractor/extraction_context.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,11 @@
from dataclasses import dataclass
from typing import Any, Callable, cast

import torch
from transformers import StoppingCriteria

from .types import ExtractionContext


@dataclass
class ExtractionContext:
check_aborted: Callable[[], bool]
max_tokens: int | None = None
max_output_tokens: int | None = None
input_tokens: int = 0
output_tokens: int = 0


class ExtractionAbortedError(Exception):
Expand Down
38 changes: 20 additions & 18 deletions doc_page_extractor/extractor.py
Original file line number Diff line number Diff line change
@@ -1,35 +1,37 @@
import tempfile
from dataclasses import dataclass
from os import PathLike
from pathlib import Path
from typing import Generator, cast

from PIL import Image

from .check_env import check_env
from .extraction_context import ExtractionContext
from .model import DeepSeekOCRModel, DeepSeekOCRSize
from .model import DeepSeekOCRHugginfaceModel
from .parser import ParsedItemKind, parse_ocr_response
from .redacter import background_color, redact
from .types import Layout, PageExtractor, ExtractionContext, DeepSeekOCRModel, DeepSeekOCRSize


@dataclass
class Layout:
ref: str
det: tuple[int, int, int, int]
text: str | None

def create_page_extractor(
    model_path: PathLike | None = None,
    local_only: bool = False,
) -> PageExtractor:
    """Build a page extractor backed by ``DeepSeekOCRHugginfaceModel``.

    Args:
        model_path: Location passed to the model as a ``Path``. Any falsy
            value (including the default ``None``) is forwarded as ``None``.
        local_only: Forwarded unchanged to the model constructor.

    Returns:
        A ``_PageExtractorImpls`` instance wrapping the newly created model.
    """
    # Truthiness check, not ``is None``: falsy paths are treated as absent.
    resolved_path: Path | None = None
    if model_path:
        resolved_path = Path(model_path)

    backend: DeepSeekOCRHugginfaceModel = DeepSeekOCRHugginfaceModel(
        model_path=resolved_path,
        local_only=local_only,
    )
    return _PageExtractorImpls(backend)

class PageExtractor:
def __init__(
self,
model_path: PathLike | None = None,
local_only: bool = False,
) -> None:
self._model: DeepSeekOCRModel = DeepSeekOCRModel(
model_path=Path(model_path) if model_path else None,
local_only=local_only,
)
def create_page_extractor_with_model(model: DeepSeekOCRModel) -> PageExtractor:
    """Build a page extractor around a caller-supplied model object.

    Args:
        model: Object that must pass ``isinstance(model, DeepSeekOCRModel)``.

    Returns:
        A ``_PageExtractorImpls`` instance wrapping ``model``.

    Raises:
        TypeError: If ``model`` fails the ``isinstance`` check.
    """
    if isinstance(model, DeepSeekOCRModel):
        return _PageExtractorImpls(model)
    raise TypeError("model must implement DeepSeekOCRModel protocol")


class _PageExtractorImpls:
def __init__(self, model: DeepSeekOCRModel) -> None:
self._model: DeepSeekOCRModel = model

def download_models(self) -> None:
self._model.download()
Expand Down
3 changes: 2 additions & 1 deletion doc_page_extractor/injection.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,8 @@

from transformers import StoppingCriteria

from .extraction_context import AbortStoppingCriteria, ExtractionContext
from .types import ExtractionContext
from .extraction_context import AbortStoppingCriteria

_LOCAL = threading.local()
_LOCAL_KEY = "value"
Expand Down
6 changes: 3 additions & 3 deletions doc_page_extractor/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from dataclasses import dataclass
from importlib.util import find_spec
from pathlib import Path
from typing import Any, Literal
from typing import Any

import torch
from huggingface_hub import snapshot_download
Expand All @@ -12,8 +12,8 @@

from .extraction_context import ExtractionContext
from .injection import InferWithInterruption, preprocess_model
from .types import DeepSeekOCRSize

DeepSeekOCRSize = Literal["tiny", "small", "base", "large", "gundam"]


@dataclass
Expand All @@ -40,7 +40,7 @@ class _SizeConfig:
_Models = tuple[Any, Any]


class DeepSeekOCRModel:
class DeepSeekOCRHugginfaceModel:
def __init__(self, model_path: Path | None, local_only: bool) -> None:
if local_only and model_path is None:
raise ValueError("model_path must be provided when local_only is True")
Expand Down
2 changes: 1 addition & 1 deletion doc_page_extractor/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from PIL.Image import Image
from PIL.ImageFont import FreeTypeFont, load_default

from .extractor import Layout
from .types import Layout

_FRAGMENT_COLOR = (0x49, 0xCF, 0xCB) # Light Green
_Color = tuple[int, int, int]
Expand Down
62 changes: 62 additions & 0 deletions doc_page_extractor/types.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
from dataclasses import dataclass
from typing import Protocol, runtime_checkable
from typing import Generator, Literal, Callable

from PIL import Image


# Closed set of size names accepted wherever a ``size`` argument is typed as
# ``DeepSeekOCRSize``.
# NOTE(review): what each name selects is defined by the model implementation,
# not here — confirm there.
DeepSeekOCRSize = Literal["tiny", "small", "base", "large", "gundam"]

@dataclass
class Layout:
ref: str
det: tuple[int, int, int, int]
text: str | None

@dataclass
class ExtractionContext:
check_aborted: Callable[[], bool]
max_tokens: int | None = None
max_output_tokens: int | None = None
input_tokens: int = 0
output_tokens: int = 0


@runtime_checkable
class PageExtractor(Protocol):
    """Structural interface for page extractors.

    Because the protocol is ``runtime_checkable``, ``isinstance`` checks
    against it only verify that the listed methods exist, not their
    signatures.
    """

    def download_models(self) -> None:
        """Download the models the extractor depends on."""

    def load_models(self) -> None:
        """Load the models the extractor depends on."""

    def extract(
        self,
        image: Image.Image,
        size: DeepSeekOCRSize,
        stages: int = 1,
        context: ExtractionContext | None = None,
    ) -> Generator[tuple[Image.Image, list[Layout]], None, None]:
        """Yield ``(image, layouts)`` pairs for the given page image.

        Args:
            image: Page image to process.
            size: One of the ``DeepSeekOCRSize`` names.
            stages: Integer, default ``1``.
                NOTE(review): presumably the number of passes — confirm
                against the implementation.
            context: Optional ``ExtractionContext``; ``None`` by default.
        """


@runtime_checkable
class DeepSeekOCRModel(Protocol):
    """Structural interface for the OCR model behind a page extractor.

    Because the protocol is ``runtime_checkable``, ``isinstance`` checks
    against it only verify that the four methods exist, not their
    signatures.
    """

    def download(self) -> None:
        """Download the model."""

    def load(self) -> None:
        """Load the model."""

    def unload(self) -> None:
        """Unload the model."""

    def generate(
        self,
        image: Image.Image,
        prompt: str,
        temp_path: str,
        size: DeepSeekOCRSize,
        context: ExtractionContext | None,
    ) -> str:
        """Run the model on ``image`` with ``prompt`` and return a string.

        Args:
            image: Input image.
            prompt: Prompt text given to the model.
            temp_path: Path string.
                NOTE(review): presumably a scratch location for
                intermediate files — confirm against the implementation.
            size: One of the ``DeepSeekOCRSize`` names.
            context: Optional ``ExtractionContext``; there is no default,
                so callers must pass it explicitly, ``None`` included.
        """
4 changes: 2 additions & 2 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

from PIL import Image

from doc_page_extractor import ExtractionContext, PageExtractor, plot
from doc_page_extractor import plot, create_page_extractor, ExtractionContext

_ABORT_TIMEOUT = 9999.0 # seconds

Expand All @@ -12,7 +12,7 @@ def main() -> None:
project_root = Path(__file__).parent
image_dir_path = project_root / "tests" / "images"
image_name = "double_column.png"
extractor = PageExtractor(
extractor = create_page_extractor(
model_path=project_root / "models-cache",
local_only=False,
)
Expand Down
4 changes: 2 additions & 2 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

5 changes: 0 additions & 5 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -48,11 +48,6 @@ name = "pytorch-cu121"
url = "https://download.pytorch.org/whl/cu121"
priority = "explicit"

[[tool.poetry.source]]
name = "pytorch-cu118"
url = "https://download.pytorch.org/whl/cu118"
priority = "explicit"

[tool.poetry.group.dev.dependencies]
pylint = "^3.3.7"
# Development environment includes CUDA 12.1 PyTorch
Expand Down