Metal backend: enable Parakeet #16562
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/16562
Note: Links to docs will display an error until the docs builds have been completed.
⏳ No Failures, 101 Pending (as of commit a9b0d95 with merge base 2c59f85)
This comment was automatically generated by Dr. CI and updates every 15 minutes.
Pull request overview
This PR enables support for the Parakeet model in the Metal backend by implementing the necessary operators, fixing existing issues, and adding export/lowering infrastructure. The changes build upon existing Metal backend functionality to support this specific ASR model.
Changes:
- Added Metal backend support to Parakeet export script with custom linear decomposition to avoid addmm and handle tensor reinterpretation
- Implemented `aoti_torch_mps_bmm_out` for batched matrix multiplication and fixed grouped convolution input channel handling
- Enhanced shim layer to support non-zero tensor storage offsets and new tensor handle creation
Reviewed changes
Copilot reviewed 9 out of 9 changed files in this pull request and generated 6 comments.
| File | Description |
|---|---|
| examples/models/parakeet/export_parakeet_tdt.py | Added Metal backend support with custom decomposition and dtype handling |
| examples/models/parakeet/README.md | Updated documentation for Metal export and runner usage |
| backends/apple/metal/runtime/shims/memory.cpp | Implemented storage offset support and new tensor handle creation |
| backends/apple/metal/runtime/shims/et_metal_ops.mm | Added bmm_out implementation and fixed grouped convolution |
| backends/apple/metal/runtime/shims/et_metal_ops.h | Added bmm_out function declaration |
| backends/apple/metal/runtime/shims/et_metal.mm | Added metal_buffer_nocopy function |
| backends/apple/metal/runtime/shims/et_metal.h | Added metal_buffer_nocopy declaration |
| backends/apple/metal/metal_backend.py | Updated supported fallback kernels list |
| backends/aoti/common_shims.cpp | Improved error messages for stubbed functions |
    return partitioner, programs


    def _linear_bias_decomposition(input, weight, bias=None):
We naturally decompose this; if it's not being decomposed, it's probably because the aoti backend specifies not to decompose it. So we should probably just disable that for Metal.
Your comment made me realize I should add a comment explaining this. This decomposition is the key to making things work. Linear does get decomposed, but not in a way that works for us. When a bias is present and the tensors are 2D, it gets decomposed into addmm. That would require us to implement addmm in the Metal backend, but more importantly, reinterpret_tensor_wrapper gets called on the bias to make it look 2D. That eventually reaches ExecuTorch as a call to executorch::extension::from_blob with a 0 stride (to view a 1D tensor as 2D). ExecuTorch doesn't currently support that, and raises an error.
This decomposition avoids that problem, and also avoids having to implement addmm.
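For illustration, a minimal sketch of what such a decomposition could look like, matching the `_linear_bias_decomposition` signature from the diff above (the body here is an assumption, not necessarily the PR's exact implementation):

```python
import torch

def _linear_bias_decomposition(input, weight, bias=None):
    # Express linear as matmul + broadcast add. Inductor then never
    # emits addmm, and the 1-D bias is never reinterpreted as a
    # 0-strided 2-D tensor (which from_blob would reject).
    out = torch.matmul(input, weight.t())
    if bias is not None:
        out = out + bias  # ordinary broadcasting; bias stays 1-D
    return out
```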
    @autoreleasepool {
      try {
        // Convert AOTITensorHandle to ExecutorTorch tensors
        auto out_tensor = reinterpret_cast<Tensor*>(out);
Nit: It would be useful to wrap this cast in a helper function since we plan on changing the tensor definition backing the shim.
That would be better suited for a separate PR, since all the other ops do it this way. Maybe just leave it for when I do the migration to SlimTensor.
    // Create cache key for this batched matrix multiplication
    // Cache key includes: op_name, shape params {B, M, K, N}, dtype, transpose_flag
    // This allows reuse when same BMM shape/dtype is called repeatedly
Does this happen often with transformers? I would think the shape mostly changes each iteration with growing context length
I am seeing bmm being called on encode, not on decode. There are 72 calls to bmm in the Inductor-generated code for encoding, with what appear to be 2 different shape combinations. So this caching avoids the redundant creation of 70 MPSGraphs.
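As a rough illustration of the idea (sketched in Python for brevity; the actual implementation is Objective-C++ around MPSGraph objects, and the exact key layout below is an assumption based on the cache-key comment in the diff):

```python
# Graph cache keyed by shape/dtype/transpose: a graph is built only the
# first time an unseen combination appears, so 72 bmm calls with 2
# distinct shape combinations build just 2 graphs.
_bmm_graph_cache = {}

def get_or_build_bmm_graph(B, M, K, N, dtype, transposed, build_graph):
    key = ("bmm_out", B, M, K, N, str(dtype), transposed)
    graph = _bmm_graph_cache.get(key)
    if graph is None:
        graph = build_graph()  # expensive one-time graph construction
        _bmm_graph_cache[key] = graph
    return graph
```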
    (void)device; // Used for validation, consistent with other ops

    // Get Metal buffers for input and output tensors
    id<MTLBuffer> self_buffer = get_mtl_buffer(self_tensor, "aoti_torch_mps_bmm_out", "self");
Just sort of curious: why do you pass the fn name and arg name? Are these keys to a dict?
just for logging
      ET_LOG(Error, "aoti_torch_mps_bmm_out: Failed to get Metal device");
      return Error::Internal;
    }
    (void)device; // Used for validation, consistent with other ops
What's going on here?
unnecessary call to get_metal_device, cleaned up
    // Validate tensor dimensions - bmm requires 3-D tensors
    if (self_tensor->dim() != 3 || mat2_tensor->dim() != 3 || out_tensor->dim() != 3) {
      ET_LOG(Error, "aoti_torch_mps_bmm_out: tensors must be 3-D. "
If the failure is that the number of dims is wrong, why is it useful to print the shape? Especially since dims could be too big?
cleaned up (originally introduced while debugging)
mergennachin left a comment
See inline comments
    - (void)new_handle;
    - throw std::runtime_error("Not implemented");
    - return Error::Internal;
    + ET_LOG(Debug, "aoti_torch_new_tensor_handle: entered");
Should you handle both zero and non-zero offset in this method?
What do you mean? aoti_torch_new_tensor_handle doesn't take an offset.
Pull request overview
Copilot reviewed 9 out of 9 changed files in this pull request and generated 3 comments.
mergennachin left a comment
Thank you!
This pull request builds on top of #16499 to introduce support for the Parakeet model in the Metal backend. The most important changes are grouped below:
Parakeet export/lowering:
- Added Metal backend support to the Parakeet export script, with a custom linear decomposition to avoid addmm and handle tensor reinterpretation

Operator updates:
- Implemented `aoti_torch_mps_bmm_out` to support batched matrix multiplication (bmm) in the Metal backend
- Fixed `aoti_torch_mps_convolution` by reading the correct dimension from the weight tensor

Shim layer updates:
- Implemented `aoti_torch_new_tensor_handle`
- Fixed `aoti_torch__reinterpret_tensor` by adjusting the data pointer instead of rejecting non-zero offsets, and updated memory tracking and Metal buffer mapping logic accordingly
- Added a `metal_buffer_nocopy` function to map arbitrary memory pointers into Metal buffers, supporting cases where the data pointer is offset
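For intuition, a minimal sketch of the offset handling described above (hypothetical Python; the real code would do the equivalent pointer arithmetic in C++ before wrapping the memory in a Metal buffer):

```python
def adjusted_data_ptr(base_ptr: int, storage_offset: int, element_size: int) -> int:
    # Rather than rejecting a non-zero storage_offset, advance the base
    # pointer to where the view's data actually starts.
    return base_ptr + storage_offset * element_size

# e.g. a float32 view starting 5 elements into its storage:
assert adjusted_data_ptr(0x1000, 5, 4) == 0x1014
```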