# WIP: Enable Parakeet on Metal #16499
## Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/16499
Note: Links to docs will display an error until the docs builds have been completed.
❌ 3 New Failures, 1 Unrelated Failure as of commit fc45eed with merge base 55fe42b.
NEW FAILURES - The following jobs have failed:
UNSTABLE - The following job is marked as unstable, possibly due to flakiness on trunk:
This comment was automatically generated by Dr. CI and updates every 15 minutes.
Force-pushed from f39750d to 385acb3.
@JacobSzwejbka @manuelcandales @Gasoonjia Here's the latest update:
Here's the stack trace from the assertion failure: https://gist.github.com/mergennachin/8a43299f4e74c16da3dd2946275737cf
Force-pushed from 385acb3 to bfe5b77.
I have been looking at the inductor-generated code, and this is the line from the generated code that causes the error: The error is produced by that first call to
So, this turns into a call to make_tensor_ptr with data containing
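For orientation, here is a minimal, self-contained sketch of how a view described by strides and a storage offset addresses into a flat buffer, and why a 0 stride makes every index along that dimension alias the same element. This is illustrative only: the `element_offset` helper is hypothetical, and it is not the generated code referenced above.

```cpp
// Illustrative sketch only; not the inductor-generated code from this PR.
// Shows how (index, strides, storage_offset) address into a flat buffer,
// and why a 0 stride aliases the same element along a dimension.
#include <cstdint>
#include <iostream>
#include <vector>

// Hypothetical helper: compute the flat element offset of a multi-dimensional
// index into a view described by strides and a storage offset (in elements).
int64_t element_offset(const std::vector<int64_t>& index,
                       const std::vector<int64_t>& strides,
                       int64_t storage_offset) {
  int64_t off = storage_offset;
  for (size_t d = 0; d < index.size(); ++d) {
    off += index[d] * strides[d];  // a stride of 0 contributes nothing
  }
  return off;
}

int main() {
  // A 2x3 view with strides {0, 1} over a 3-element buffer:
  // both rows alias the same underlying data.
  std::vector<float> buffer = {1.f, 2.f, 3.f};
  std::vector<int64_t> strides = {0, 1};
  for (int64_t i = 0; i < 2; ++i) {
    for (int64_t j = 0; j < 3; ++j) {
      std::cout << buffer[element_offset({i, j}, strides, /*storage_offset=*/0)] << ' ';
    }
    std::cout << '\n';
  }
  return 0;
}
```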
Force-pushed from bfe5b77 to 71c25a8.
Force-pushed from 71c25a8 to fc45eed.
Closing in lieu of #16562
This pull request builds on top of #16499 to introduce support for the Parakeet model in the Metal backend. The most important changes are grouped below:

### Parakeet export/lowering
* Added support for Metal lowering to the Parakeet export/lowering script.
* Provided a custom linear decomposition. This achieved two objectives: it avoided the call to addmm, and it avoided the call to reinterpret_tensor_wrapper with a 0 stride.

### Operator updates
* Added an implementation of `aoti_torch_mps_bmm_out` to support batched matrix multiplication (bmm) in the Metal backend.
* Fixed input channel dimension handling for grouped convolutions in `aoti_torch_mps_convolution` by reading the correct dimension from the weight tensor.

### Shim layer updates
* Added an implementation of `aoti_torch_new_tensor_handle`.
* Enabled non-zero tensor storage offsets in `aoti_torch__reinterpret_tensor` by adjusting the data pointer instead of rejecting non-zero offsets, and updated the memory tracking and Metal buffer mapping logic accordingly (see the sketch below).
* Added the `metal_buffer_nocopy` function to map arbitrary memory pointers into Metal buffers, supporting cases where the data pointer is offset.
* Improved error messages in several stubbed shim functions by including the function name in the exception message for easier debugging.
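As a rough illustration of the storage-offset handling described in the shim layer updates, here is a minimal sketch, assuming the offset is expressed in elements; the `offset_data_ptr` helper is hypothetical and this is not the actual ExecuTorch shim code.

```cpp
// Illustrative sketch only; not the ExecuTorch shim implementation.
// Honors a non-zero storage offset by advancing the raw data pointer
// (the offset is in elements, so it is converted to bytes via the dtype size)
// before the pointer is mapped into a Metal buffer.
#include <cassert>
#include <cstddef>
#include <cstdint>

// Hypothetical helper for illustration.
inline void* offset_data_ptr(void* base,
                             int64_t storage_offset,
                             size_t element_size) {
  return static_cast<uint8_t*>(base) +
         storage_offset * static_cast<int64_t>(element_size);
}

int main() {
  // A view starting 2 float elements into the buffer sees 3.f first.
  float buffer[5] = {1.f, 2.f, 3.f, 4.f, 5.f};
  auto* view = static_cast<float*>(
      offset_data_ptr(buffer, /*storage_offset=*/2, sizeof(float)));
  assert(view[0] == 3.f);
  return 0;
}
```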