diff --git a/optimum/exporters/executorch/recipes/metal.py b/optimum/exporters/executorch/recipes/metal.py index 0ca1c56..cf7fe8b 100644 --- a/optimum/exporters/executorch/recipes/metal.py +++ b/optimum/exporters/executorch/recipes/metal.py @@ -31,6 +31,8 @@ METAL_BACKEND_AVAILABLE = False if METAL_BACKEND_AVAILABLE: + import torch + from tabulate import tabulate from torch.export import ExportedProgram @@ -52,6 +54,13 @@ ) from ..recipe_registry import register_recipe + def _linear_bias_decomposition(input, weight, bias=None): + weight_t = torch.ops.aten.t.default(weight) + out = torch.ops.aten.matmul.default(input, weight_t) + if bias is not None: + return torch.ops.aten.add.Tensor(out, bias) + return out + @register_recipe("metal") def export_to_executorch_with_metal( model: Union[ @@ -89,6 +98,13 @@ def _lower_to_executorch( if len(exported_programs) == 1: exported_programs = {"forward": next(iter(exported_programs.values()))} + # Decompose linear+bias into matmul+add to avoid addmm, + # which the Metal backend doesn't support. + for key in exported_programs: + decomp_table = torch.export.default_decompositions() + decomp_table[torch.ops.aten.linear.default] = _linear_bias_decomposition + exported_programs[key] = exported_programs[key].run_decompositions(decomp_table) + partitioners = { key: [MetalPartitioner([MetalBackend.generate_method_name_compile_spec(key)])] for key in exported_programs.keys()