
Conversation

Contributor

@manuelcandales manuelcandales commented Jan 13, 2026

This pull request builds on top of #16499 to introduce support for the Parakeet model in the Metal backend. The most important changes are grouped below:

Parakeet export/lowering:

  • Added Metal lowering support to the Parakeet export/lowering script
  • Provided a custom linear decomposition, which achieves two objectives: it avoids the call to addmm and it avoids the call to reinterpret_tensor_wrapper with a 0 stride

Operator updates:

  • Added implementation for aoti_torch_mps_bmm_out to support batched matrix multiplication (bmm) in the Metal backend
  • Fixed input channel dimension handling for grouped convolutions in aoti_torch_mps_convolution by reading the correct dimension from the weight tensor.
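
The grouped-convolution fix above hinges on the weight layout: for a grouped convolution, the weight tensor is shaped (out_channels, in_channels // groups, kH, kW), so the expected input channel count must be recovered from the weight rather than read off dimension 1 directly. A minimal sketch (hypothetical helper for illustration, not the PR's actual code):

```python
# Hypothetical sketch (not the PR's code): recover the expected input
# channel count from a grouped-conv weight tensor, whose layout is
# (out_channels, in_channels // groups, kH, kW).
def expected_input_channels(weight_shape, groups):
    out_channels, in_channels_per_group, kh, kw = weight_shape
    return in_channels_per_group * groups

# Conv2d(in_channels=8, out_channels=16, kernel_size=3, groups=4)
# has a weight of shape (16, 2, 3, 3):
print(expected_input_channels((16, 2, 3, 3), groups=4))  # 8
```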

Shim layer updates:

  • Added implementation for aoti_torch_new_tensor_handle
  • Enabled non-zero tensor storage offsets in aoti_torch__reinterpret_tensor by adjusting the data pointer instead of rejecting non-zero offsets, and updating memory tracking and Metal buffer mapping logic accordingly.
  • Added the metal_buffer_nocopy function to map arbitrary memory pointers into Metal buffers, supporting cases where the data pointer is offset.
  • Improved error messages in several stubbed shim functions by including the function name in the exception message for easier debugging.
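
The storage-offset change above can be illustrated with a small arithmetic sketch (hypothetical helper and dtype table, not the shim's actual code): instead of rejecting a non-zero storage offset, the base data pointer is advanced by offset * element_size bytes before the Metal buffer is mapped.

```python
# Hypothetical sketch (not the shim's code): advance the base data
# pointer by the view's storage offset, in bytes, rather than
# rejecting non-zero offsets outright.
ELEMENT_SIZES = {"float32": 4, "float16": 2, "int64": 8}

def adjusted_data_pointer(base_address, storage_offset, dtype):
    return base_address + storage_offset * ELEMENT_SIZES[dtype]

# A float32 view starting 16 elements into its storage begins
# 64 bytes past the base address:
print(adjusted_data_pointer(0x1000, 16, "float32") - 0x1000)  # 64
```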

Copilot AI review requested due to automatic review settings January 13, 2026 16:56
pytorch-bot bot commented Jan 13, 2026

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/16562

Note: Links to docs will display an error until the docs builds have been completed.

⏳ No Failures, 101 Pending

As of commit a9b0d95 with merge base 2c59f85:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@meta-cla meta-cla bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Jan 13, 2026
@manuelcandales manuelcandales added the release notes: none Do not include this in the release notes label Jan 13, 2026
Copilot AI review requested due to automatic review settings January 13, 2026 17:23
Contributor

Copilot AI left a comment

Pull request overview

This PR enables support for the Parakeet model in the Metal backend by implementing the necessary operators, fixing existing issues, and adding export/lowering infrastructure. The changes build upon existing Metal backend functionality to support this specific ASR model.

Changes:

  • Added Metal backend support to Parakeet export script with custom linear decomposition to avoid addmm and handle tensor reinterpretation
  • Implemented aoti_torch_mps_bmm_out for batched matrix multiplication and fixed grouped convolution input channel handling
  • Enhanced shim layer to support non-zero tensor storage offsets and new tensor handle creation

Reviewed changes

Copilot reviewed 9 out of 9 changed files in this pull request and generated 6 comments.

| File | Description |
| --- | --- |
| examples/models/parakeet/export_parakeet_tdt.py | Added Metal backend support with custom decomposition and dtype handling |
| examples/models/parakeet/README.md | Updated documentation for Metal export and runner usage |
| backends/apple/metal/runtime/shims/memory.cpp | Implemented storage offset support and new tensor handle creation |
| backends/apple/metal/runtime/shims/et_metal_ops.mm | Added bmm_out implementation and fixed grouped convolution |
| backends/apple/metal/runtime/shims/et_metal_ops.h | Added bmm_out function declaration |
| backends/apple/metal/runtime/shims/et_metal.mm | Added metal_buffer_nocopy function |
| backends/apple/metal/runtime/shims/et_metal.h | Added metal_buffer_nocopy declaration |
| backends/apple/metal/metal_backend.py | Updated supported fallback kernels list |
| backends/aoti/common_shims.cpp | Improved error messages for stubbed functions |


return partitioner, programs


def _linear_bias_decomposition(input, weight, bias=None):
Contributor

We naturally decompose this; if it's not being decomposed, it's because the AOTI backend probably specifies not to decompose it. So we should probably just disable that for Metal.

Contributor Author

Your comment made me realize I should add a comment explaining this. This decomposition is the key to making things work. Linear does get decomposed, but not in a way that works for us. When a bias is present and the tensors are 2D, linear gets decomposed into addmm. That would require us to implement addmm in the Metal backend, but more importantly, reinterpret_tensor_wrapper gets called on the bias to make it look 2D. That eventually makes its way to ExecuTorch as a call to executorch::extension::from_blob with a 0 stride (to view a 1D tensor as a 2D tensor). ExecuTorch doesn't currently support that and raises an error.
This decomposition avoids that problem, and also avoids having to implement addmm.
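
A minimal sketch of the decomposition described above (illustrative only; names and shapes are assumptions, not the PR's exact code): expressing linear as a matmul plus a broadcasted add lowers to mm/bmm rather than addmm, and the 1-D bias never needs a zero-stride view to look 2-D.

```python
import torch

# Illustrative sketch, not the PR's exact code: express linear as
# matmul + broadcasted add. matmul lowers to mm/bmm rather than addmm,
# and the 1-D bias broadcasts over the last dimension, so no
# zero-stride reinterpretation of the bias is required.
def linear_decomposition(input, weight, bias=None):
    out = torch.matmul(input, weight.t())
    if bias is not None:
        out = out + bias  # broadcasts (4,) over (2, 4)
    return out

x = torch.randn(2, 3)
w = torch.randn(4, 3)
b = torch.randn(4)
expected = torch.nn.functional.linear(x, w, b)
assert torch.allclose(linear_decomposition(x, w, b), expected, atol=1e-6)
```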

@autoreleasepool {
  try {
    // Convert AOTITensorHandle to ExecuTorch tensors
    auto out_tensor = reinterpret_cast<Tensor*>(out);
Contributor

Nit: It would be useful to wrap this cast in a helper function since we plan on changing the tensor definition backing the shim.

Contributor Author

That would be better suited for a separate PR, since all other ops do it this way. Maybe just leave it for when I do the migration to SlimTensor.


// Create cache key for this batched matrix multiplication
// Cache key includes: op_name, shape params {B, M, K, N}, dtype, transpose_flag
// This allows reuse when same BMM shape/dtype is called repeatedly
Contributor

Does this happen often with transformers? I would think the shape mostly changes each iteration with growing context length.

Contributor Author

I am seeing bmm being called on encode, not on decode. I am seeing 72 calls to bmm in the inductor-generated code for encoding, with what seem to be 2 different shape combinations. So this caching avoids the redundant creation of 70 MPSGraphs.
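
The caching behavior described above can be sketched like this (the key format and helper are hypothetical; the real cache lives in the Metal shim): a string key built from op name, shape parameters, dtype, and transpose flag gates graph construction, so repeated same-shape calls reuse one compiled graph.

```python
# Hypothetical sketch of the graph cache (key format is illustrative):
# a key built from op name, {B, M, K, N}, dtype and transpose flag
# gates graph construction, so repeated same-shape calls reuse a graph.
def bmm_cache_key(op_name, B, M, K, N, dtype, transposed):
    return f"{op_name}:{B}x{M}x{K}x{N}:{dtype}:t{int(transposed)}"

cache = {}
shapes = [(8, 128, 64, 128), (8, 128, 128, 64)] * 36  # 72 calls, 2 shapes
for B, M, K, N in shapes:
    key = bmm_cache_key("bmm", B, M, K, N, "float16", False)
    if key not in cache:
        cache[key] = object()  # stands in for a compiled MPSGraph
print(len(cache))  # 2 graphs built for 72 calls
```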

(void)device; // Used for validation, consistent with other ops

// Get Metal buffers for input and output tensors
id<MTLBuffer> self_buffer = get_mtl_buffer(self_tensor, "aoti_torch_mps_bmm_out", "self");
Contributor

Just sort of curious: why do you pass the fn name and arg name? Are these keys to a dict?

Contributor Author

just for logging

ET_LOG(Error, "aoti_torch_mps_bmm_out: Failed to get Metal device");
return Error::Internal;
}
(void)device; // Used for validation, consistent with other ops
Contributor

What's going on here?

Contributor Author

Unnecessary call to get_metal_device; cleaned up.


// Validate tensor dimensions - bmm requires 3-D tensors
if (self_tensor->dim() != 3 || mat2_tensor->dim() != 3 || out_tensor->dim() != 3) {
ET_LOG(Error, "aoti_torch_mps_bmm_out: tensors must be 3-D. "
Contributor

If the failure is that the number of dims is wrong, why is it useful to print the shape? Especially since dims could be too big?

Contributor Author

cleaned up (originally introduced while debugging)

Contributor

@mergennachin mergennachin left a comment

See inline comments

(void)new_handle;
throw std::runtime_error("Not implemented");
return Error::Internal;
ET_LOG(Debug, "aoti_torch_new_tensor_handle: entered");
Contributor

Should you handle both zero and non-zero offset in this method?

Contributor Author

What do you mean? aoti_torch_new_tensor_handle doesn't take an offset.

Copilot AI review requested due to automatic review settings January 14, 2026 17:38
Contributor

Copilot AI left a comment

Pull request overview

Copilot reviewed 9 out of 9 changed files in this pull request and generated 3 comments.



Contributor

@mergennachin mergennachin left a comment

Thank you!

@manuelcandales manuelcandales merged commit 27b778d into main Jan 14, 2026
131 of 132 checks passed
@manuelcandales manuelcandales deleted the manuel/parakeet-metal branch January 14, 2026 18:06
