Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 55 additions & 5 deletions export/orbax/export/data_processors/jax_data_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,15 +27,58 @@
from .third_party.neptune.protos import manifest_pb2


def _jax_spec_from(spec: Any) -> jax.ShapeDtypeStruct:
  """Converts a tensor spec to a `jax.ShapeDtypeStruct`.

  Args:
    spec: Either a `shlo_function.ShloTensorSpec`, or any object exposing both
      a `shape` and a `dtype` attribute.

  Returns:
    A `jax.ShapeDtypeStruct` built from the shape and dtype of `spec`.

  Raises:
    ValueError: If `spec` is not a `ShloTensorSpec` and is missing a `shape`
      or `dtype` attribute.
  """
  if isinstance(spec, shlo_function.ShloTensorSpec):
    # bf16 is mapped explicitly instead of going through
    # `shlo_dtype_to_np_dtype`.
    if spec.dtype == shlo_function.ShloDType.bf16:
      return jax.ShapeDtypeStruct(spec.shape, jax.numpy.bfloat16)
    return jax.ShapeDtypeStruct(
        spec.shape, shlo_function.shlo_dtype_to_np_dtype(spec.dtype)
    )
  # Anything else must at least look like a tensor spec; unknown leaves are
  # rejected instead of being passed through unchanged.
  if hasattr(spec, 'shape') and hasattr(spec, 'dtype'):
    return jax.ShapeDtypeStruct(
        shape=tuple(spec.shape),
        dtype=spec.dtype,
    )
  raise ValueError(f'Unsupported spec type: {type(spec)}')


class _JaxShapeSpecGenerator:
  """Generates unique shape spec strings for symbolic_args_specs.

  An instance is a callable mapping one spec (any object with a `shape`
  attribute) to a shape string such as `'(b, d_0, 256)'`. The instance keeps a
  counter so that names generated for unnamed (`None`) dimensions stay unique
  across successive calls.
  """

  def __init__(self):
    # Next index to try for an auto-generated `d_<index>` name.
    self._counter = 0
    # Symbolic names supplied by the caller as string dimensions so far.
    self._user_names = set()

  def _fresh_name(self) -> str:
    """Returns the next `d_<n>` name not already used as a user-given name."""
    while True:
      name = f'd_{self._counter}'
      self._counter += 1
      if name not in self._user_names:
        return name

  def __call__(self, spec: Any) -> str:
    """Builds the shape string for a single spec.

    Args:
      spec: An object with a `shape` attribute. `shape` may be `None` or an
        iterable whose items are strings (used verbatim), `None` (replaced by
        a generated name) or anything else (rendered with `str`).

    Returns:
      `'...'` if `spec.shape` is `None`, `'()'` for an empty shape, otherwise
      a parenthesised, comma-separated list of dimensions (with a trailing
      comma for a single dimension). A `None` in the first position becomes
      `'b'`; a `None` elsewhere becomes a fresh `d_<n>` name.

    Raises:
      ValueError: If `spec` has no `shape` attribute, or if `spec.shape` is
        neither `None` nor iterable.
    """
    # PyTree leaves like int/float don't have a shape attribute.
    if not hasattr(spec, 'shape'):
      raise ValueError(f'Unsupported spec type: {type(spec)}')
    if spec.shape is None:
      return '...'
    try:
      shape_list = list(spec.shape)
    except (TypeError, ValueError) as e:
      raise ValueError(
          f'spec.shape must be iterable, got {spec.shape}'
      ) from e

    if not shape_list:
      return '()'

    # Record user-provided names before generating any, so that a generated
    # `d_<n>` never silently aliases a dimension the caller named `d_<n>`
    # (equal names denote the same symbolic dimension). Names supplied in a
    # later call cannot be taken into account retroactively.
    self._user_names.update(d for d in shape_list if isinstance(d, str))

    dims = []
    for i, d in enumerate(shape_list):
      if isinstance(d, str):
        dims.append(d)
      elif d is None:
        # A leading unknown dimension always gets the shared name 'b'.
        dims.append('b' if i == 0 else self._fresh_name())
      else:
        dims.append(str(d))
    if len(dims) == 1:
      return f'({dims[0]},)'
    return f"({', '.join(dims)})"


class JaxDataProcessor(data_processor_base.DataProcessor):
Expand Down Expand Up @@ -112,7 +155,10 @@ def prepare(

self._input_signature = input_signature

jax_input_signature = jax.tree.map(_jax_spec_from, self._input_signature)
jax_input_args = jax.tree.map(_jax_spec_from, self._input_signature)
jax_input_shapes_specs = jax.tree.map(
_JaxShapeSpecGenerator(), self._input_signature
)

# Construct args_spec for jax2obm.convert.
# We assume the callable takes (params, inputs) if params is not None,
Expand All @@ -123,7 +169,8 @@ def prepare(
params_spec = jax.tree.map(
lambda x: jax.ShapeDtypeStruct(x.shape, x.dtype), self._params
)
args_spec = (params_spec, jax_input_signature)
args = (params_spec, jax_input_args)
shapes_specs = ('...', jax_input_shapes_specs)

ckp_path = self._options.checkpoint_path or 'processor_checkpoint'

Expand All @@ -143,7 +190,10 @@ def _save_checkpoint(

self._save_fn = _save_checkpoint
else:
args_spec = (jax_input_signature,)
args = (jax_input_args,)
shapes_specs = (jax_input_shapes_specs,)

args_spec = jax.export.symbolic_args_specs(args, shapes_specs)

# Instructs the runtime to only load the model parameters from the
# checkpoint, not all keys present in the checkpoint.
Expand Down
117 changes: 117 additions & 0 deletions export/orbax/export/data_processors/jax_data_processor_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

from collections.abc import Mapping
import pathlib
import re
from typing import Any

import jax
Expand Down Expand Up @@ -166,6 +167,122 @@ def add(x: jax.Array) -> jax.Array:
('cpu', 'tpu'),
)

def test_prepare_with_polymorphic_shapes(self):
  """Checks that prepare() succeeds for an input with a symbolic dimension."""

  def add(x: jax.Array) -> jax.Array:
    return x + 1.0

  # 'b' is a symbolic (named) leading dimension.
  symbolic_input = jax.ShapeDtypeStruct(('b', 3), jnp.float32)
  processor = jax_data_processor.JaxDataProcessor(add, name='add')
  processor.prepare(symbolic_input)

  self.assertIsNotNone(processor.obm_function)
  self.assertIsNotNone(processor.input_signature)
  self.assertIsNotNone(processor.output_signature)

def test_prepare_with_polymorphic_shapes_signatures(self):
  """Checks the signatures reported after prepare() with a symbolic input."""

  def add(x: jax.Array) -> jax.Array:
    return x + 1.0

  processor = jax_data_processor.JaxDataProcessor(add, name='add')
  processor.prepare(jax.ShapeDtypeStruct(('b', 3), jnp.float32))

  expected_input = jax.ShapeDtypeStruct(('b', 3), jnp.float32)
  self.assertEqual(processor.input_signature, expected_input)

  # Note: The underlying OBM Function signature uses None for dynamic
  # dimensions, regardless of the symbolic string provided by the user.
  out_spec = processor.output_signature
  out_dims = list(out_spec.shape)  # pytype: disable=attribute-error
  self.assertEqual(out_dims, [None, 3])

def test_prepare_with_polymorphic_shapes_none(self):
  """Checks the signatures after prepare() with a `None` leading dimension."""

  def add(x: jax.Array) -> jax.Array:
    return x + 1.0

  processor = jax_data_processor.JaxDataProcessor(add, name='add')
  processor.prepare(jax.ShapeDtypeStruct((None, 3), jnp.float32))

  expected_input = jax.ShapeDtypeStruct((None, 3), jnp.float32)
  self.assertEqual(processor.input_signature, expected_input)

  out_spec = processor.output_signature
  out_dims = list(out_spec.shape)  # pytype: disable=attribute-error
  self.assertEqual(out_dims, [None, 3])


class JaxShapeSpecGeneratorTest(parameterized.TestCase):
  """Unit tests for `jax_data_processor._JaxShapeSpecGenerator`."""

  @parameterized.named_parameters(
      {'testcase_name': 'empty_shape', 'shape': (), 'expected': '()'},
      {'testcase_name': 'one_dim_none', 'shape': (None,), 'expected': '(b,)'},
      {'testcase_name': 'one_dim_int', 'shape': (4,), 'expected': '(4,)'},
      {
          'testcase_name': 'multi_dim_first_none',
          'shape': (None, 4),
          'expected': '(b, 4)',
      },
      {
          'testcase_name': 'multi_dim_second_none',
          'shape': (4, None),
          'expected': '(4, d_0)',
      },
      {
          'testcase_name': 'multi_dim_both_none',
          'shape': (None, None),
          'expected': '(b, d_0)',
      },
      {
          'testcase_name': 'multi_dim_all_none',
          'shape': (None, None, None, 256),
          'expected': '(b, d_0, d_1, 256)',
      },
      {
          'testcase_name': 'string_dims',
          'shape': ('foo', 4),
          'expected': '(foo, 4)',
      },
      {
          'testcase_name': 'string_and_none',
          'shape': ('foo', None),
          'expected': '(foo, d_0)',
      },
  )
  def test_jax_shape_spec_generator(self, expected, shape=None):
    """A fresh generator renders each shape as the expected string."""
    make_spec_string = jax_data_processor._JaxShapeSpecGenerator()
    actual = make_spec_string(jax.ShapeDtypeStruct(shape, jnp.float32))
    self.assertEqual(actual, expected)

  def test_jax_shape_spec_generator_unsupported_type(self):
    """An object without a `shape` attribute is rejected with ValueError."""
    spec = object()
    make_spec_string = jax_data_processor._JaxShapeSpecGenerator()
    with self.assertRaisesRegex(
        ValueError, f'Unsupported spec type: {re.escape(str(type(spec)))}'
    ):
      make_spec_string(spec)

  def test_jax_shape_spec_generator_multiple_calls(self):
    """Generated dimension names keep counting up across calls."""
    make_spec_string = jax_data_processor._JaxShapeSpecGenerator()
    first = jax.ShapeDtypeStruct((None, None), jnp.float32)
    second = jax.ShapeDtypeStruct((None, None, 256), jnp.float32)
    self.assertEqual(make_spec_string(first), '(b, d_0)')
    self.assertEqual(make_spec_string(second), '(b, d_1, 256)')

  def test_jax_shape_spec_generator_none_shape(self):
    """A spec whose shape is None is rendered as '...'."""

    class ShapeIsNone:
      shape = None

    make_spec_string = jax_data_processor._JaxShapeSpecGenerator()
    self.assertEqual(make_spec_string(ShapeIsNone()), '...')

  def test_jax_shape_spec_generator_uniterable_shape(self):
    """A non-iterable, non-None shape is rejected with ValueError."""

    class ShapeIsScalar:
      shape = 4

    make_spec_string = jax_data_processor._JaxShapeSpecGenerator()
    with self.assertRaisesRegex(ValueError, 'spec.shape must be iterable'):
      make_spec_string(ShapeIsScalar())


# Run the tests in this module when the file is executed as a script.
if __name__ == '__main__':
  googletest.main()
Loading