From 9ec01a41a2056358380e91ba2137ddf2368aee6a Mon Sep 17 00:00:00 2001 From: Antoine Zambelli Date: Sat, 20 Jun 2026 00:35:53 -0500 Subject: [PATCH] refactor(clients): extract shared think-tag parsing into forge.prompts.think_tags MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The /[THINK] delimiter regex and its extraction helper were duplicated: a full copy (capture groups + _extract_think_tags) in clients/llamafile.py and a second strip-only copy of the regex in prompts/templates.py. Adding reasoning parity to other clients would have meant a third copy. Collapse them into one source of truth, forge/prompts/think_tags.py (THINK_TAG_RE + extract_think_tags): - llamafile.py re-imports it under the historical private name (extract_think_tags as _extract_think_tags) so existing call sites and the test import keep working unchanged. - templates.py imports the shared regex; its only use is .sub("", ...), so the added capture groups are behavior-neutral. Placed under forge.prompts (not forge.clients) because clients/__init__.py eagerly imports every client, which would create an import cycle if templates.py imported a clients-based helper. Pure refactor — no behavior change. 163 unit tests pass (llamafile + templates + vllm). Co-Authored-By: Claude Opus 4.8 (1M context) Claude-Session: https://claude.ai/code/session_01EpuVYCYeb1DhWfynCVyA6a --- src/forge/clients/llamafile.py | 29 +++------------------ src/forge/prompts/templates.py | 7 +---- src/forge/prompts/think_tags.py | 46 +++++++++++++++++++++++++++++++++ 3 files changed, 50 insertions(+), 32 deletions(-) create mode 100644 src/forge/prompts/think_tags.py diff --git a/src/forge/clients/llamafile.py b/src/forge/clients/llamafile.py index f34ae53..af8bacd 100644 --- a/src/forge/clients/llamafile.py +++ b/src/forge/clients/llamafile.py @@ -22,15 +22,9 @@ from forge.core.workflow import LLMResponse, TextResponse, ToolCall, ToolSpec from forge.errors import BackendError, ContextDiscoveryError from forge.prompts.templates import build_tool_prompt, extract_tool_call - -# Model-specific thinking tag formats. Extend this list when adding new model -# families. If a model library/registry is added later, move these patterns -# into per-model profiles instead of hard-coding here. -# - [THINK]...[/THINK] — Mistral (Ministral Reasoning) -# - ... — Qwen3, DeepSeek -_THINK_TAG_RE = re.compile( - r"\[THINK\](.*?)\[/THINK\]|(.*?)", re.DOTALL -) +# Re-exported under the historical private name so existing imports +# (`from forge.clients.llamafile import _extract_think_tags`) keep working. +from forge.prompts.think_tags import extract_think_tags as _extract_think_tags # Multi-shard GGUF naming convention: "-00001-of-00003.gguf". The shard # index is filesystem layout, not model identity, so strip it for the @@ -38,23 +32,6 @@ _SHARD_SUFFIX_RE = re.compile(r"-\d{5}-of-\d{5}$") -def _extract_think_tags(text: str) -> tuple[str, str]: - """Extract thinking blocks from text. - - Supports [THINK]...[/THINK] (Mistral) and ... (Qwen/DeepSeek). - Returns (reasoning, remaining_content). - """ - reasoning_parts: list[str] = [] - remaining = text - for m in _THINK_TAG_RE.finditer(text): - # group(1) is [THINK] match, group(2) is match - content = (m.group(1) or m.group(2) or "").strip() - reasoning_parts.append(content) - if reasoning_parts: - remaining = _THINK_TAG_RE.sub("", text).strip() - return "\n\n".join(reasoning_parts), remaining - - def _merge_consecutive(messages: list[dict[str, Any]]) -> list[dict[str, Any]]: """Ensure strict user/assistant alternation for Jinja parity checker. diff --git a/src/forge/prompts/templates.py b/src/forge/prompts/templates.py index d978dd7..a411a40 100644 --- a/src/forge/prompts/templates.py +++ b/src/forge/prompts/templates.py @@ -6,6 +6,7 @@ import re from forge.core.workflow import ToolCall, ToolSpec +from forge.prompts.think_tags import THINK_TAG_RE as _THINK_TAG_RE def build_tool_prompt(tools: list[ToolSpec]) -> str: @@ -121,12 +122,6 @@ def _try_parse_tool_call(json_str: str, available_tools: list[str]) -> ToolCall r"(\w+)\[ARGS\](\{.*\})", re.DOTALL ) -# Think tag patterns (same as llamafile._THINK_TAG_RE) — needed to strip -# thinking blocks before rescue parsing. -_THINK_TAG_RE = re.compile( - r"\[THINK\].*?\[/THINK\]|.*?", re.DOTALL -) - # Qwen Coder XML tool call format. # # value diff --git a/src/forge/prompts/think_tags.py b/src/forge/prompts/think_tags.py new file mode 100644 index 0000000..bf026cb --- /dev/null +++ b/src/forge/prompts/think_tags.py @@ -0,0 +1,46 @@ +"""Thinking/reasoning tag parsing shared across client adapters. + +Reasoning models wrap their chain-of-thought in delimiter tags. When the +backend's reasoning parser is absent — or doesn't split a given model's output +into a dedicated field — that thinking arrives inline in the message +``content`` instead. This module is the single source of truth for detecting +and extracting those blocks, used by the client adapters (to populate +``ToolCall.reasoning`` and to clean ``TextResponse`` content) and by the +prompt-rescue path in ``templates`` (to strip thinking before parsing a +rehearsed tool call). + +Supported delimiters: + - ``[THINK]...[/THINK]`` — Mistral (Ministral Reasoning) + - ``...`` — Qwen3, DeepSeek + +Extend ``THINK_TAG_RE`` when adding a new model family. If a model +library/registry is added later, move these patterns into per-model profiles +instead of hard-coding here. +""" + +from __future__ import annotations + +import re + +THINK_TAG_RE = re.compile( + r"\[THINK\](.*?)\[/THINK\]|(.*?)", re.DOTALL +) + + +def extract_think_tags(text: str) -> tuple[str, str]: + """Split thinking blocks out of ``text``. + + Returns ``(reasoning, remaining_content)``: the concatenated thinking + blocks (joined by blank lines) and the text with those blocks removed and + stripped. When no tags are present, ``reasoning`` is the empty string and + ``remaining_content`` is the original text unchanged. + """ + reasoning_parts: list[str] = [] + remaining = text + for m in THINK_TAG_RE.finditer(text): + # group(1) is the [THINK] body, group(2) is the body. + content = (m.group(1) or m.group(2) or "").strip() + reasoning_parts.append(content) + if reasoning_parts: + remaining = THINK_TAG_RE.sub("", text).strip() + return "\n\n".join(reasoning_parts), remaining