Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ model.infer_multi(
### SGLang

Set up the environment (uv-managed virtualenv). Install the local SGLang wheel first,
then pin `kernels==0.9.0` and install PyMuPDF for PDF-to-image conversion:
then pin `kernels==0.11.7` and install PyMuPDF for PDF-to-image conversion:
```shell
uv venv --python 3.12
source .venv/bin/activate
Expand Down Expand Up @@ -265,8 +265,16 @@ Useful options:
--model_dir baidu/Unlimited-OCR # Local path or Hugging Face model ID
--gpu 0 # CUDA_VISIBLE_DEVICES value
--server_log ./log/sglang_server.log
--ngram_size 35 # 0 disables the no-repeat-ngram logit processor
--ngram_window 1024 # README recommends 128 for single image, 1024 for multi-page
--resume # skip images/pages whose .md already exists
--results_jsonl ./results.jsonl # structured per-request records
--max_pages 5 # PDF mode: process at most the first N pages
--max_images 5 # image_dir mode: process at most the first N images
```

Note: for `--pdf` mode, `--image_mode` must be `base` (the script enforces this).


## Visualization

Expand Down
219 changes: 177 additions & 42 deletions infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,11 @@
PROMPT = "document parsing."
TEMPERATURE = 0
CONTEXT_LENGTH = 32768
NO_REPEAT_NGRAM_SIZE = 35
NGRAM_WINDOW = 128
# Defaults follow the README's documented settings:
# single image (gundam): ngram_size=35, ngram_window=128
# multi-page / PDF: ngram_size=35, ngram_window=1024
DEFAULT_NGRAM_SIZE = 35
DEFAULT_NGRAM_WINDOW = 1024
REQUEST_TIMEOUT = 1200
MAX_RETRIES = 5
NO_REPEAT_NGRAM_PROCESSOR_STR = None
Expand All @@ -47,19 +50,27 @@ def get_ngram_processor_str():
return NO_REPEAT_NGRAM_PROCESSOR_STR


def pdf_to_images(pdf_path: str, dpi: int = 300) -> list[str]:
def pdf_to_images(pdf_path: str, dpi: int = 300, max_pages: int | None = None) -> tuple[list[str], tempfile.TemporaryDirectory]:
"""Render PDF pages to PNGs in a temporary directory.

The returned TemporaryDirectory owns the tempdir; callers should keep
it alive for the duration of inference and then let it be garbage-collected
(or call .cleanup()) to release the rendered PNGs.
"""
import fitz

doc = fitz.open(pdf_path)
tmp_dir = tempfile.mkdtemp(prefix="pdf_ocr_")
tmp = tempfile.TemporaryDirectory(prefix="pdf_ocr_")
image_paths = []
mat = fitz.Matrix(dpi / 72, dpi / 72)
for i, page in enumerate(doc):
out_path = os.path.join(tmp_dir, f"page_{i + 1:04d}.png")
page.get_pixmap(matrix=mat).save(out_path)
image_paths.append(out_path)
doc.close()
return image_paths
with fitz.open(pdf_path) as doc:
page_iter = doc
if max_pages is not None:
page_iter = list(doc)[:max_pages]
for i, page in enumerate(page_iter):
out_path = os.path.join(tmp.name, f"page_{i + 1:04d}.png")
page.get_pixmap(matrix=mat).save(out_path)
image_paths.append(out_path)
return image_paths, tmp


def encode_image(image_path: str) -> dict:
Expand Down Expand Up @@ -148,11 +159,17 @@ def stop_server(process):
process._log_file.close()


def collect_stream_silent(resp, output_file: str | None) -> dict:
def collect_stream_silent(resp, output_file: str | None, write_output: bool) -> tuple[int, float, str]:
"""Stream a server response, optionally writing the result to output_file.

Returns (tokens, decode_time, text). When write_output is False (e.g. on
failed requests or when --resume skipped a file), no output file is opened
and any partial file from a previous run is left untouched.
"""
chunks = []
token_count = 0
first_token_time = None
f = open(output_file, "w", encoding="utf-8") if output_file else None
f = open(output_file, "w", encoding="utf-8") if (output_file and write_output) else None
try:
for raw_line in resp.iter_lines():
if not raw_line:
Expand Down Expand Up @@ -182,10 +199,16 @@ def collect_stream_silent(resp, output_file: str | None) -> dict:

end_time = time.time()
decode_time = (end_time - first_token_time) if first_token_time and token_count > 1 else 0
return {"tokens": token_count, "decode_time": decode_time, "text": "".join(chunks)}
return token_count, decode_time, "".join(chunks)


def infer_one(image_path: str, output_file: str | None, args, idx: int) -> dict:
def infer_one(
image_path: str,
output_file: str | None,
args,
idx: int,
write_output: bool = True,
) -> dict:
payload = {
"model": SERVED_MODEL_NAME,
"messages": [{"role": "user", "content": build_content(image_path)}],
Expand All @@ -194,14 +217,16 @@ def infer_one(image_path: str, output_file: str | None, args, idx: int) -> dict:
"stream": True,
"images_config": {"image_mode": args.image_mode},
}
if NO_REPEAT_NGRAM_SIZE > 0 and NGRAM_WINDOW > 0:
if args.ngram_size > 0 and args.ngram_window > 0:
payload["custom_logit_processor"] = get_ngram_processor_str()
payload["custom_params"] = {
"ngram_size": NO_REPEAT_NGRAM_SIZE,
"window_size": NGRAM_WINDOW,
"ngram_size": args.ngram_size,
"window_size": args.ngram_window,
}

name = os.path.basename(image_path)
last_error: str | None = None
result: dict | None = None
for attempt in range(MAX_RETRIES):
try:
resp = requests.post(
Expand All @@ -215,43 +240,63 @@ def infer_one(image_path: str, output_file: str | None, args, idx: int) -> dict:
time.sleep(3 * (attempt + 1))
continue
resp.raise_for_status()
result = collect_stream_silent(resp, output_file)
print(f" [{idx}] {name}: {result['tokens']} tokens, {result['decode_time']:.1f}s")
return result
tokens, decode_time, _text = collect_stream_silent(resp, output_file, write_output)
print(f" [{idx}] {name}: {tokens} tokens, {decode_time:.1f}s")
result = {"status": "ok", "tokens": tokens, "decode_time": decode_time, "error": None}
break
except Exception as e:
last_error = repr(e)
if attempt < MAX_RETRIES - 1:
print(f" [{idx}] {name}: retry {attempt + 1}/{MAX_RETRIES} ({e})")
time.sleep(3 * (attempt + 1))
continue
print(f" [{idx}] {name}: FAILED ({e})")
return {"tokens": 0, "decode_time": 0, "text": ""}
result = {"status": "failed", "tokens": 0, "decode_time": 0, "error": last_error}
if result is None:
result = {"status": "failed", "tokens": 0, "decode_time": 0, "error": "no attempts made"}
return result


def collect_dataset_images(image_dir: str) -> list[str]:
def collect_dataset_images(image_dir: str, max_images: int | None = None) -> list[str]:
exts = (".png", ".jpg", ".jpeg", ".webp", ".bmp")
image_files = []
for root, _, files in os.walk(image_dir):
for name in files:
if name.lower().endswith(exts):
image_files.append(os.path.join(root, name))
return sorted(image_files, key=lambda f: os.path.getsize(f), reverse=True)
image_files.sort()
if max_images is not None:
image_files = image_files[:max_images]
return image_files


def build_jobs(args) -> tuple[list[tuple[str, str | None]], tempfile.TemporaryDirectory | None]:
"""Build (image_path, output_file) jobs.

def build_jobs(args) -> list[tuple[str, str | None]]:
Returns the jobs list and an optional TemporaryDirectory that the caller
must keep alive for the duration of inference (so the rendered PDF
pages stay on disk). It is None for the image_dir mode.
"""
pdf_tmp: tempfile.TemporaryDirectory | None = None
if args.pdf:
image_files = pdf_to_images(args.pdf, dpi=PDF_DPI)
if args.image_mode == "gundam":
raise ValueError(
"--image_mode gundam is not supported with --pdf: multi-page parsing "
"requires --image_mode base. See README."
)
image_files, pdf_tmp = pdf_to_images(args.pdf, dpi=PDF_DPI, max_pages=args.max_pages)
prefix = os.path.splitext(os.path.basename(args.pdf))[0]
jobs = []
for i, image_path in enumerate(image_files):
output_file = None
if args.output_dir:
output_file = os.path.join(args.output_dir, f"{prefix}_page_{i + 1:04d}.md")
jobs.append((image_path, output_file))
return jobs
return jobs, pdf_tmp

if not args.image_dir:
raise ValueError("Either --image_dir or --pdf is required")
image_files = collect_dataset_images(args.image_dir)
image_files = collect_dataset_images(args.image_dir, max_images=args.max_images)

jobs = []
for image_path in image_files:
Expand All @@ -261,44 +306,121 @@ def build_jobs(args) -> list[tuple[str, str | None]]:
stem = os.path.splitext(rel)[0].replace(os.sep, "__")
output_file = os.path.join(args.output_dir, f"{stem}.md")
jobs.append((image_path, output_file))
return jobs
return jobs, pdf_tmp


def _already_done(output_file: str | None) -> bool:
"""Resume helper: True if the output file exists and is non-empty."""
if not output_file:
return False
try:
return os.path.getsize(output_file) > 0
except OSError:
return False


class ResultsWriter:
"""Appends one JSON record per request to a JSONL file."""

def __init__(self, path: str | None):
self.path = path
self._fh = open(path, "w", encoding="utf-8") if path else None

def write(self, record: dict) -> None:
if self._fh is None:
return
self._fh.write(json.dumps(record, ensure_ascii=False) + "\n")
self._fh.flush()

def close(self) -> None:
if self._fh is not None:
self._fh.close()
self._fh = None


def run(args):
jobs = build_jobs(args)
jobs, pdf_tmp = build_jobs(args)
if args.output_dir:
os.makedirs(args.output_dir, exist_ok=True)

mode = "pdf_pages" if args.pdf else "dataset_images"
print(f"Mode: {mode}, requests={len(jobs)}, concurrency={args.concurrency}, image_mode={args.image_mode}")
print(f"Mode: {mode}, jobs={len(jobs)}, concurrency={args.concurrency}, image_mode={args.image_mode}")

results_writer = ResultsWriter(args.results_jsonl)
wall_start = time.time()
results = []
with ThreadPoolExecutor(max_workers=args.concurrency) as executor:
futures = {
executor.submit(infer_one, image_path, output_file, args, i + 1): image_path
for i, (image_path, output_file) in enumerate(jobs)
}
for future in as_completed(futures):
results.append(future.result())
try:
with ThreadPoolExecutor(max_workers=args.concurrency) as executor:
futures = {}
skipped = 0
for i, (image_path, output_file) in enumerate(jobs):
if args.resume and _already_done(output_file):
skipped += 1
print(f" [{i + 1}] {os.path.basename(image_path)}: skipped (already done)")
results.append({"status": "skipped", "tokens": 0, "decode_time": 0, "error": None})
if results_writer.path:
results_writer.write({
"index": i + 1,
"name": os.path.basename(image_path),
"status": "skipped",
"tokens": 0,
"decode_time_s": 0.0,
"wall_time_s": 0.0,
"output_file": output_file,
"error": None,
})
continue
futures[executor.submit(infer_one, image_path, output_file, args, i + 1, True): (i + 1, image_path, output_file)]

for future in as_completed(futures):
idx, image_path, output_file = futures[future]
request_start = time.time()
r = future.result()
wall = time.time() - request_start
results.append(r)
if results_writer.path:
results_writer.write({
"index": idx,
"name": os.path.basename(image_path),
"status": r["status"],
"tokens": r["tokens"],
"decode_time_s": r["decode_time"],
"wall_time_s": round(wall, 3),
"output_file": output_file,
"error": r["error"],
})
finally:
results_writer.close()
if pdf_tmp is not None:
pdf_tmp.cleanup()

wall_time = time.time() - wall_start
ok_results = [r for r in results if r["status"] == "ok"]
failed_results = [r for r in results if r["status"] == "failed"]
skipped_results = [r for r in results if r["status"] == "skipped"]
total_tokens = sum(r["tokens"] for r in results)
successful = sum(1 for r in results if r["tokens"] > 0)
successful = len(ok_results)

print(f"\n{'=' * 60}")
print("Concurrent Results:")
print(f" Requests: {successful}/{len(jobs)}")
print(f" Requests: ok={successful}, failed={len(failed_results)}, skipped={len(skipped_results)} (total {len(results)})")
print(f" Total tokens: {total_tokens}")
print(f" Wall time: {wall_time:.2f}s")
if wall_time > 0:
print(f" System TPS: {total_tokens / wall_time:.2f} tokens/s")
if successful > 0:
avg_decode = sum(r["decode_time"] for r in results if r["tokens"] > 0) / successful
avg_decode = sum(r["decode_time"] for r in ok_results) / successful
avg_tokens = total_tokens / successful
print(f" Avg tokens/request: {avg_tokens:.0f}")
print(f" Avg decode_time/request: {avg_decode:.2f}s")
if args.results_jsonl:
print(f" Results JSONL: {args.results_jsonl}")
print(f"{'=' * 60}")

# Non-zero exit on full failure so CI / pipelines can detect it.
if jobs and successful == 0 and len(failed_results) > 0:
sys.exit(1)


def parse_args():
parser = argparse.ArgumentParser(
Expand All @@ -311,7 +433,20 @@ def parse_args():
parser.add_argument("--concurrency", type=int, default=8)
parser.add_argument("--gpu", default="0")
parser.add_argument("--model_dir", default="baidu/Unlimited-OCR")
parser.add_argument("--image_mode", choices=("gundam", "base"), default="gundam")
parser.add_argument("--image_mode", choices=("gundam", "base"), default="base",
help="Use 'gundam' for single-image high-res; 'base' is required for PDF / multi-page.")
parser.add_argument("--ngram_size", type=int, default=DEFAULT_NGRAM_SIZE,
help="No-repeat ngram size for the custom logit processor. 0 disables it.")
parser.add_argument("--ngram_window", type=int, default=DEFAULT_NGRAM_WINDOW,
help="Ngram window size. README recommends 128 for single image, 1024 for multi-page.")
parser.add_argument("--resume", action="store_true",
help="Skip images / pages whose .md already exists and is non-empty.")
parser.add_argument("--results_jsonl", default="",
help="If set, write one JSON record per request to this path.")
parser.add_argument("--max_pages", type=int, default=None,
help="(PDF mode) Process at most the first N pages.")
parser.add_argument("--max_images", type=int, default=None,
help="(image_dir mode) Process at most the first N images.")
parser.add_argument("--server_log", default="./log/sglang_server.log")
return parser.parse_args()

Expand Down