Skip to content
Closed
2 changes: 2 additions & 0 deletions mellea/backends/huggingface.py
Original file line number Diff line number Diff line change
Expand Up @@ -1378,6 +1378,7 @@ class used during generation, if any.
generate_log.result = mot

mot._generate_log = generate_log
mot._format = _format

async def _generate_from_raw(
self,
Expand Down Expand Up @@ -1513,6 +1514,7 @@ async def _generate_from_raw(
generate_log.action = action

result._generate_log = generate_log
result._format = format
Comment thread
planetf1 marked this conversation as resolved.
results.append(result)

usage: dict[str, Any] | None = (
Expand Down
1 change: 1 addition & 0 deletions mellea/backends/litellm.py
Original file line number Diff line number Diff line change
Expand Up @@ -596,6 +596,7 @@ async def post_processing(
generate_log.action = mot._action
generate_log.result = mot
mot._generate_log = generate_log
mot._format = _format

# Extract token usage from full response dict or streaming usage
full_response = mot._meta.get("litellm_full_response")
Expand Down
2 changes: 2 additions & 0 deletions mellea/backends/ollama.py
Original file line number Diff line number Diff line change
Expand Up @@ -613,6 +613,7 @@ async def _generate_from_raw(
generate_log.extra["error"] = error
generate_log.extra["empty_response"] = response.model_dump()
result._generate_log = generate_log
result._format = format

results.append(result)

Expand Down Expand Up @@ -742,6 +743,7 @@ async def post_processing(
generate_log.result = mot

mot._generate_log = generate_log
mot._format = _format
mot._generate = None

# Extract token counts from response
Expand Down
1 change: 1 addition & 0 deletions mellea/backends/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -1127,6 +1127,7 @@ async def post_processing(
generate_log.action = mot._action
generate_log.result = mot
mot._generate_log = generate_log
mot._format = _format

# Extract token usage from response or streaming usage
response = mot._meta["oai_chat_response"]
Expand Down
1 change: 1 addition & 0 deletions mellea/backends/watsonx.py
Original file line number Diff line number Diff line change
Expand Up @@ -614,6 +614,7 @@ async def post_processing(
generate_log.result = mot
generate_log.action = mot._action
mot._generate_log = generate_log
mot._format = _format

async def _generate_from_raw(
self,
Expand Down
59 changes: 57 additions & 2 deletions mellea/core/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,11 @@
ParamSpec,
Protocol,
TypeVar,
cast,
runtime_checkable,
)

import pydantic
import typing_extensions
from PIL import Image as PILImage

Expand Down Expand Up @@ -401,6 +403,7 @@ def __init__(
# Mellea-side hook correlation ID; distinct from the provider-assigned
# `GenerationMetadata.response_id`.
self._generation_id: str | None = None
self._format: type[pydantic.BaseModel] | None = None
Comment thread
planetf1 marked this conversation as resolved.

def _record_ttfb(self) -> None:
"""Record time-to-first-byte if streaming and not yet recorded."""
Expand Down Expand Up @@ -542,6 +545,7 @@ def _copy_from(self, other: ModelOutputThunk) -> None:
self._thinking = other._thinking
self.generation = other.generation
self._generate_log = other._generate_log
self._format = other._format
self._cancelled = other._cancelled
# _cancel_hook is deliberately not copied: _copy_from swaps output state,
# not backend-thread plumbing, which is tied to the original computation.
Expand All @@ -557,7 +561,13 @@ def is_computed(self) -> bool:

@property
def value(self) -> str | None:
"""Gets the value of the block."""
"""Gets the raw string value of the block.

When ``format=`` is set on the originating ``act()``/``instruct()`` call, the
model returns a JSON string and ``.value`` contains that raw JSON — not a
Pydantic instance. Use ``.parsed`` on a ``ComputedModelOutputThunk`` to get
the validated model object.
"""
if not self._computed:
return None
return self._underlying_value
Expand Down Expand Up @@ -776,6 +786,7 @@ def __copy__(self) -> ModelOutputThunk:
copied._action = self._action
copied._context = self._context
copied._generate_log = self._generate_log
copied._format = self._format
copied._model_options = self._model_options
copied.generation = copy(self.generation)
return copied
Expand Down Expand Up @@ -810,6 +821,7 @@ def __deepcopy__(self, memo: dict) -> ModelOutputThunk:
self._context
) # The items in a context should be immutable.
deepcopied._generate_log = copy(self._generate_log)
deepcopied._format = self._format
deepcopied._model_options = copy(self._model_options)
deepcopied.generation = deepcopy(self.generation)
return deepcopied
Expand Down Expand Up @@ -873,14 +885,57 @@ async def astream(self) -> str:

@property
def value(self) -> str:
"""Gets the value of the block."""
"""Gets the raw string value of the block.

When ``format=`` is set on the originating ``act()``/``instruct()`` call, the
model returns a JSON string and ``.value`` contains that raw JSON — not a
Pydantic instance. Use ``.parsed`` to get the validated model object.
"""
return self._underlying_value # type: ignore

@value.setter
def value(self, v: str):
"""Sets the value of the block."""
self._underlying_value = v

@property
def parsed(self) -> S | None:
"""Returns the result as a validated Pydantic instance when ``format=`` was set.

The return type is ``S | None``, where ``S`` is the thunk's type parameter.
The ``format=`` overloads do not yet bind ``S`` to the format model, so
callers must parameterize the thunk explicitly to get a narrowed type::

thunk = cast(ComputedModelOutputThunk[MyModel], result)
obj = thunk.parsed # typed MyModel | None, no model_validate_json needed

Returns ``None`` when no ``format=`` type was provided. Unlike
``parsed_repr`` (which holds the action-specific parse result),
``.parsed`` always re-validates the raw JSON string against ``_format``
via ``model_validate_json``.

Note:
This property relies on the originating backend storing the format
type on the thunk. Custom backend authors must set ``mot._format``
in their ``post_processing`` method (mirroring the built-in
backends); otherwise ``.parsed`` always returns ``None`` even when
``format=`` was supplied.

Returns:
An instance of the format type (``S``) produced by
``model_validate_json``, or ``None`` if no format type was set.

Raises:
pydantic.ValidationError: If the raw JSON value does not conform to
the format model (e.g. the model returned malformed structured output).
"""
if self._format is None:
Comment thread
planetf1 marked this conversation as resolved.
return None
# `_format` is always a pydantic model type; `model_validate_json` returns
# `pydantic.BaseModel` statically, but the caller's type parameter `S` is
# the concrete model when `format=` was used, so we cast the result to `S`.
return cast(S, self._format.model_validate_json(self.value))

def is_computed(self) -> Literal[True]:
"""Returns `True` since thunk is always computed.

Expand Down
4 changes: 3 additions & 1 deletion mellea/stdlib/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -434,7 +434,9 @@ def act(
requirements: used as additional requirements when a sampling strategy is provided
strategy: a SamplingStrategy that describes the strategy for validating and repairing/retrying for the instruct-validate-repair pattern. None means that no particular sampling strategy is used.
return_sampling_results: attach the (successful and failed) sampling attempts to the results.
format: if set, the BaseModel to use for constrained decoding.
format: if set, the BaseModel to use for constrained decoding. When
provided, ``.value`` on the returned thunk is always a raw JSON string —
use ``.parsed`` to obtain the validated Pydantic model instance.
model_options: additional model options, which will upsert into the model/backend's defaults.
tool_calls: if true, tool calling is enabled.

Expand Down
18 changes: 18 additions & 0 deletions test/backends/test_huggingface.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,6 +252,24 @@ class Email(pydantic.BaseModel):
)


@pytest.mark.qualitative
def test_parsed_returns_pydantic_instance(session) -> None:
class Sentiment(pydantic.BaseModel):
label: str

output = session.instruct(
"Classify the sentiment of 'I love this!' as the single word "
"positive, negative, or neutral. Respond with a label field.",
format=Sentiment,
model_options={ModelOption.MAX_NEW_TOKENS: 2**8},
)

parsed = output.parsed
assert isinstance(parsed, Sentiment)
assert isinstance(parsed.label, str) and parsed.label
assert parsed == Sentiment.model_validate_json(output.value)


@pytest.mark.qualitative
async def test_generate_from_raw(session) -> None:
prompts = [
Expand Down
18 changes: 18 additions & 0 deletions test/backends/test_ollama.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,24 @@ class Email(pydantic.BaseModel):
# assert email.to.email_address.endswith("example.com")


@pytest.mark.qualitative
def test_parsed_returns_pydantic_instance(session) -> None:
class Sentiment(pydantic.BaseModel):
label: str

output = session.instruct(
"Classify the sentiment of 'I love this!' as the single word "
"positive, negative, or neutral. Respond with a label field.",
format=Sentiment,
model_options={ModelOption.MAX_NEW_TOKENS: 2**8},
)

parsed = output.parsed
assert isinstance(parsed, Sentiment)
assert isinstance(parsed.label, str) and parsed.label
assert parsed == Sentiment.model_validate_json(output.value)


@pytest.mark.qualitative
@pytest.mark.timeout(150)
async def test_generate_from_raw(session) -> None:
Expand Down
69 changes: 68 additions & 1 deletion test/core/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,17 @@
import io
from typing import Any

import pydantic
import pytest
from PIL import Image as PILImage

from mellea.core import CBlock, Component, ImageBlock, ModelOutputThunk
from mellea.core import (
CBlock,
Component,
ComputedModelOutputThunk,
ImageBlock,
ModelOutputThunk,
)
from mellea.stdlib.components import Message


Expand Down Expand Up @@ -317,3 +324,63 @@ async def _absorbs_first_cancel() -> None:
await asyncio.wait_for(mot._generate, timeout=1.0) # type: ignore[attr-defined]
except (TimeoutError, asyncio.CancelledError):
pass


# --- ComputedModelOutputThunk.parsed ---


class _Label(pydantic.BaseModel):
label: str


def _make_computed(
json_str: str, fmt: type[pydantic.BaseModel] | None
) -> ComputedModelOutputThunk:
thunk = ModelOutputThunk(value=json_str)
thunk._format = fmt
return ComputedModelOutputThunk(thunk)
Comment thread
planetf1 marked this conversation as resolved.


def test_parsed_returns_model_instance() -> None:
result = _make_computed('{"label": "yes"}', _Label)
obj = result.parsed
assert isinstance(obj, _Label)
assert obj.label == "yes"


def test_parsed_returns_none_when_no_format() -> None:
result = _make_computed('{"label": "yes"}', None)
assert result.parsed is None


def test_parsed_raises_on_invalid_json() -> None:
result = _make_computed("not json", _Label)
with pytest.raises(pydantic.ValidationError):
_ = result.parsed


def test_value_unaffected_by_format() -> None:
raw = '{"label": "ok"}'
result = _make_computed(raw, _Label)
assert result.value == raw


def test_format_preserved_by_copy() -> None:
import copy as _copy

result = _make_computed('{"label": "yes"}', _Label)
shallow = _copy.copy(result)
assert shallow._format is _Label
# __copy__ returns ModelOutputThunk (loses ComputedModelOutputThunk subclass due to
# zero-copy __class__ reassignment), so we validate manually rather than via .parsed.
assert shallow._format.model_validate_json(shallow.value).label == "yes" # type: ignore[union-attr]


def test_format_preserved_by_deepcopy() -> None:
import copy as _copy

result = _make_computed('{"label": "yes"}', _Label)
deep = _copy.deepcopy(result)
assert deep._format is _Label
# Same subclass-loss caveat as test_format_preserved_by_copy.
assert deep._format.model_validate_json(deep.value).label == "yes" # type: ignore[union-attr]
32 changes: 32 additions & 0 deletions test/typing/check_parsed.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
"""Mypy / pyright checks that `ComputedModelOutputThunk.parsed` tracks the type parameter.

`.parsed` is typed `S | None`, so a thunk parameterized with a Pydantic model
(`ComputedModelOutputThunk[MyModel]`) exposes `.parsed` as `MyModel | None` —
callers need no `cast()`. The `format=` overloads in `session.py` /
`functional.py` bind `S` to the format model (companion issue #1274), at which
point these checks hold end-to-end from the call site.

The `cast(X, cast(object, None))` idiom creates a typed stub value without
triggering basedpyright's ``reportInvalidCast`` rule (direct ``cast(X, None)``
raises that diagnostic because ``None`` and ``X`` share no overlap).
"""

from typing import assert_type, cast

import pydantic

from mellea.core import ComputedModelOutputThunk


class _Person(pydantic.BaseModel):
name: str


def check_parsed_tracks_format_model() -> None:
thunk = cast(ComputedModelOutputThunk[_Person], cast(object, None))
_ = assert_type(thunk.parsed, _Person | None)


def check_parsed_is_str_for_str_thunk() -> None:
thunk = cast(ComputedModelOutputThunk[str], cast(object, None))
_ = assert_type(thunk.parsed, str | None)
Loading