Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 15 additions & 5 deletions src/forge/clients/ollama.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from forge.clients.sampling_defaults import apply_sampling_defaults
from forge.core.workflow import LLMResponse, TextResponse, ToolCall, ToolSpec
from forge.errors import BackendError, ThinkingNotSupportedError
from forge.prompts.think_tags import extract_think_tags

_THINK_HEURISTIC_KEYWORDS = ("reason", "think")

Expand Down Expand Up @@ -120,12 +121,17 @@ def _resolve_reasoning(
) -> str | None:
"""Gate reasoning capture on _think flag.

When _think is False, discard all reasoning.
When True: prefer thinking field, fall back to content.
When _think is False, discard all reasoning. When True: prefer the
structured ``thinking`` field; if absent, extract ``<think>`` tags from
content; finally fall back to the raw content (an instruct model
narrating before its tool call). Mirrors LlamafileClient.
"""
if not self._think:
return None
return thinking or content or None
if thinking:
return thinking
think, _ = extract_think_tags(content)
return think or content or None

def _record_usage(self, data: dict[str, Any]) -> None:
"""Extract token usage from an Ollama response."""
Expand Down Expand Up @@ -214,7 +220,10 @@ async def send(
for i, tc in enumerate(tool_calls)
]

return TextResponse(content=msg.get("content", ""))
# No tool calls: strip inline thinking so the TextResponse carries
# clean content (parity with LlamafileClient).
_, content = extract_think_tags(msg.get("content", ""))
return TextResponse(content=content)

async def send_stream(
self,
Expand Down Expand Up @@ -321,7 +330,8 @@ async def _iter_stream(
content = msg.get("content", "")
if content:
accumulated_content += content
final = TextResponse(content=accumulated_content)
_, text = extract_think_tags(accumulated_content)
final = TextResponse(content=text)
yield StreamChunk(type=ChunkType.FINAL, response=final)
else:
tool_calls = msg.get("tool_calls")
Expand Down
46 changes: 29 additions & 17 deletions src/forge/clients/vllm.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from forge.clients.sampling_defaults import apply_sampling_defaults
from forge.core.workflow import LLMResponse, TextResponse, ToolCall, ToolSpec
from forge.errors import BackendError
from forge.prompts.think_tags import extract_think_tags


class VLLMClient:
Expand Down Expand Up @@ -163,23 +164,26 @@ def _record_usage(self, data: dict[str, Any]) -> None:
total_tokens=usage.get("total_tokens", 0),
)

def _resolve_reasoning(
self, message_or_accum: dict[str, Any] | str, accumulated_content: str = "",
) -> str | None:
"""Extract reasoning, gated on _think.

vLLM 0.21 returns reasoning in the ``reasoning`` field of the
assistant message when ``--reasoning-parser`` is enabled at server
boot. If thinking is disabled or the field is empty, return None.

Accepts either a message dict (from send()) or an accumulated
reasoning string (from send_stream()).
def _resolve_reasoning(self, reasoning: str, content: str) -> str | None:
"""Build final reasoning from the structured field and content, gated
on _think.

vLLM 0.21 returns reasoning in the ``reasoning`` field of the assistant
message when ``--reasoning-parser`` is enabled at server boot. When
that parser is absent — or doesn't split a given model's output — the
thinking instead arrives inline in ``content`` (often wrapped in
``<think>...</think>``). To avoid silently dropping it (issue #110) and
to keep send() and send_stream() in lockstep with LlamafileClient, fall
back to ``<think>``-tag extraction and then to the raw content when the
structured field is empty. Both call sites pass the same (reasoning,
content) pair, so the two paths resolve identically.
"""
if not self._think:
return None
if isinstance(message_or_accum, dict):
return message_or_accum.get("reasoning") or None
return message_or_accum or accumulated_content or None
if reasoning:
return reasoning
think, _ = extract_think_tags(content)
return think or content or None

async def send(
self,
Expand Down Expand Up @@ -227,10 +231,16 @@ async def send(
if tool_calls:
return self._parse_tool_calls(
tool_calls,
reasoning=self._resolve_reasoning(message),
reasoning=self._resolve_reasoning(
message.get("reasoning") or "", message.get("content") or "",
),
)

return TextResponse(content=message.get("content") or "")
# No tool calls: strip any inline thinking — reasoning is only useful
# attached to a ToolCall; a TextResponse carries clean content (parity
# with LlamafileClient.send()).
_, content = extract_think_tags(message.get("content") or "")
return TextResponse(content=content)

async def send_stream(
self,
Expand Down Expand Up @@ -334,7 +344,9 @@ async def send_stream(
),
)
else:
final = TextResponse(content=accumulated_content)
# Strip inline thinking from the final text for parity with send().
_, text = extract_think_tags(accumulated_content)
final = TextResponse(content=text)
yield StreamChunk(type=ChunkType.FINAL, response=final)

@staticmethod
Expand Down
81 changes: 81 additions & 0 deletions tests/unit/test_ollama_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,41 @@ async def test_think_false_discards_reasoning(self) -> None:
assert isinstance(result, list)
assert result[0].reasoning is None

@pytest.mark.asyncio
async def test_extracts_think_tags_from_content_with_tool_call(self) -> None:
"""<think> tags inline in content are extracted (not the raw tagged
string), when there is no structured thinking field."""
client = _make_client(think=True)
client._http.post.return_value = _mock_response({
"message": {
"role": "assistant",
"content": "<think>price first</think>",
"tool_calls": [
{"function": {"name": "get_pricing", "arguments": {"part": "X"}}}
],
}
})
result = await client.send(
[{"role": "user", "content": "test"}], tools=[_make_spec()]
)
assert isinstance(result, list)
assert result[0].reasoning == "price first"

@pytest.mark.asyncio
async def test_think_tags_stripped_from_text_response(self) -> None:
"""A bare text reply has <think> tags stripped from its content."""
client = _make_client()
client._http.post.return_value = _mock_response({
"message": {
"role": "assistant",
"content": "<think>pondering</think>Hello there.",
"tool_calls": [],
}
})
result = await client.send([{"role": "user", "content": "test"}])
assert isinstance(result, TextResponse)
assert result.content == "Hello there."

@pytest.mark.asyncio
async def test_think_true_explicit(self) -> None:
"""think=True explicitly → always in request body."""
Expand Down Expand Up @@ -499,6 +534,52 @@ async def test_streaming_captures_reasoning_from_deltas(self) -> None:
assert isinstance(final.response, list)
assert final.response[0].reasoning == "Let me think..."

@pytest.mark.asyncio
async def test_streaming_extracts_think_tags_from_content_with_tool_call(self) -> None:
"""#110 (streaming): inline <think> in streamed content (no thinking
deltas) is extracted onto the FINAL tool call."""
client = _make_client(think=True)
lines = [
json.dumps({"message": {"role": "assistant", "content": "<think>price "}, "done": False}),
json.dumps({"message": {"role": "assistant", "content": "first</think>"}, "done": False}),
json.dumps({
"message": {
"role": "assistant",
"content": "",
"tool_calls": [
{"function": {"name": "get_pricing", "arguments": {"part": "X"}}}
],
},
"done": True,
}),
]
client._http.stream.return_value = _MockStreamResponse(lines)
chunks = []
async for chunk in client.send_stream(
[{"role": "user", "content": "test"}], tools=[_make_spec()]
):
chunks.append(chunk)
final = [c for c in chunks if c.type == ChunkType.FINAL][0]
assert isinstance(final.response, list)
assert final.response[0].reasoning == "price first"

@pytest.mark.asyncio
async def test_streaming_strips_think_tags_from_text_response(self) -> None:
"""A streamed bare text reply has <think> tags stripped from FINAL."""
client = _make_client()
lines = [
json.dumps({"message": {"role": "assistant", "content": "<think>pondering</think>"}, "done": False}),
json.dumps({"message": {"role": "assistant", "content": "Hello there."}, "done": False}),
json.dumps({"message": {"role": "assistant", "content": ""}, "done": True}),
]
client._http.stream.return_value = _MockStreamResponse(lines)
chunks = []
async for chunk in client.send_stream([{"role": "user", "content": "test"}]):
chunks.append(chunk)
final = [c for c in chunks if c.type == ChunkType.FINAL][0]
assert isinstance(final.response, TextResponse)
assert final.response.content == "Hello there."

@pytest.mark.asyncio
async def test_streaming_thinking_preferred_over_content(self) -> None:
"""Streamed thinking tokens are preferred over content for reasoning."""
Expand Down
93 changes: 93 additions & 0 deletions tests/unit/test_vllm_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,58 @@ async def test_think_false_discards_reasoning(self) -> None:
assert isinstance(result, list)
assert result[0].reasoning is None

@pytest.mark.asyncio
async def test_extracts_think_tags_from_content_with_tool_call(self) -> None:
"""#110: thinking inline in content (no `reasoning` field) is captured."""
client = _make_client(think=True)
client._http.post.return_value = _mock_response(
_tool_call_response(content="<think>check the weather first</think>"),
)
result = await client.send(
[{"role": "user", "content": "x"}], tools=[_make_spec()],
)
assert isinstance(result, list)
assert result[0].reasoning == "check the weather first"

@pytest.mark.asyncio
async def test_reasoning_field_preferred_over_content_tags(self) -> None:
"""Structured `reasoning` field wins over <think> tags in content."""
client = _make_client(think=True)
client._http.post.return_value = _mock_response(
_tool_call_response(reasoning="structured", content="<think>inline</think>"),
)
result = await client.send(
[{"role": "user", "content": "x"}], tools=[_make_spec()],
)
assert isinstance(result, list)
assert result[0].reasoning == "structured"

@pytest.mark.asyncio
async def test_think_tags_stripped_from_text_response(self) -> None:
"""A bare text reply has <think> tags stripped from its content."""
client = _make_client()
client._http.post.return_value = _mock_response(
_text_response("<think>pondering</think>The answer is 42."),
)
result = await client.send([{"role": "user", "content": "x"}])
assert isinstance(result, TextResponse)
assert result.content == "The answer is 42."

@pytest.mark.asyncio
async def test_thinking_only_text_response_empty_after_strip(self) -> None:
"""Thinking-only reply (no answer, no tool call) strips to empty content.

Matches LlamafileClient; the empty TextResponse then rides the existing
ResponseValidator retry path (covered in the validator tests).
"""
client = _make_client()
client._http.post.return_value = _mock_response(
_text_response("<think>just thinking, no answer yet</think>"),
)
result = await client.send([{"role": "user", "content": "x"}])
assert isinstance(result, TextResponse)
assert result.content == ""

@pytest.mark.asyncio
async def test_usage_recorded(self) -> None:
client = _make_client()
Expand Down Expand Up @@ -381,6 +433,47 @@ async def test_accumulates_reasoning_across_deltas(self) -> None:
assert isinstance(result, list)
assert result[0].reasoning == "Let me think... "

@pytest.mark.asyncio
async def test_stream_extracts_think_tags_from_content_with_tool_call(self) -> None:
"""#110 (streaming): inline <think> in streamed content (no reasoning
deltas) is captured on the FINAL tool call."""
client = _make_client(think=True)
client._http.stream.return_value = _MockStreamResponse([
_sse({"choices": [{"delta": {"content": "<think>inline "}}]}),
_sse({"choices": [{"delta": {"content": "plan</think>"}}]}),
_sse({"choices": [{"delta": {
"tool_calls": [{
"index": 0,
"function": {"name": "get_weather", "arguments": '{"city": "P"}'}
}],
}}]}),
"data: [DONE]",
])
chunks = []
async for chunk in client.send_stream(
[{"role": "user", "content": "x"}], tools=[_make_spec()],
):
chunks.append(chunk)
result = [c for c in chunks if c.type == ChunkType.FINAL][0].response
assert isinstance(result, list)
assert result[0].reasoning == "inline plan"

@pytest.mark.asyncio
async def test_stream_strips_think_tags_from_text_response(self) -> None:
"""A streamed bare text reply has <think> tags stripped from FINAL."""
client = _make_client()
client._http.stream.return_value = _MockStreamResponse([
_sse({"choices": [{"delta": {"content": "<think>pondering</think>"}}]}),
_sse({"choices": [{"delta": {"content": "The answer is 42."}}]}),
"data: [DONE]",
])
chunks = []
async for chunk in client.send_stream([{"role": "user", "content": "x"}]):
chunks.append(chunk)
result = [c for c in chunks if c.type == ChunkType.FINAL][0].response
assert isinstance(result, TextResponse)
assert result.content == "The answer is 42."

@pytest.mark.asyncio
async def test_non_200_raises_backend_error(self) -> None:
client = _make_client()
Expand Down
Loading