
Conversation

@vthumbe1503
Collaborator

Description

Please include a brief summary of the changes, relevant motivation and context.

Fixes # (issue)

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Change A
  • Change B

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

vthumbe1503 and others added 2 commits January 5, 2026 18:11
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
vthumbe1503 and others added 4 commits January 6, 2026 12:34
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
…ormerEngine into cpu_fp8_optimizations

Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
@vthumbe1503
Collaborator Author

/te-ci L1 pytorch

Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
…ormerEngine into cpu_fp8_optimizations

Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
@vthumbe1503 vthumbe1503 marked this pull request as ready for review January 7, 2026 17:22
@vthumbe1503
Collaborator Author

/te-ci L1 pytorch

@greptile-apps
Contributor

greptile-apps bot commented Jan 7, 2026

Greptile Overview

Greptile Summary

This PR implements CPU-side optimizations for FP8 quantized tensor operations. The main changes include:

  • Property caching: Added cached properties for dtype, requires_grad, shape, and is_cuda in QuantizedTensor and subclasses to avoid expensive Python attribute lookups
  • Direct C API calls: Replaced pybind11 calls with direct Python C API (PyObject_Call, PyDict_SetItemString) in quantizer tensor creation to reduce overhead
  • Symbol caching: Added caching for CUDA driver symbols in cuda_driver.h to avoid repeated symbol lookups
  • GEMM check caching: Cached nvte_is_non_tn_fp8_gemm_supported() result to avoid redundant calls
  • Stride pre-computation: Computing strides in C++ rather than calling Python functions
  • Thread-safe initialization: Using std::call_once for extension initialization

The optimizations are well-targeted at reducing Python/C++ boundary overhead. However, several critical memory safety and correctness issues from previous review threads remain unaddressed and must be fixed before merging.
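As a rough illustration of the attribute-caching idea summarized above, a minimal self-contained sketch might look like the following (the class name and constructor are illustrative placeholders, not the actual TransformerEngine QuantizedTensor, which is a torch.Tensor subclass):

import torch

class CachedAttrTensorSketch:
    """Toy stand-in for QuantizedTensor: dtype and requires_grad are cached at
    creation so later accesses are plain attribute reads rather than expensive
    PyObject lookups on a torch.Tensor subclass."""

    def __init__(self, data: torch.Tensor, dtype: torch.dtype, requires_grad: bool = False):
        self._data = data                    # underlying quantized byte tensor in the real class
        self._dtype = dtype                  # cached nominal (high-precision) dtype
        self._requires_grad = requires_grad  # cached autograd flag

    @property
    def dtype(self) -> torch.dtype:
        return self._dtype

    @property
    def requires_grad(self) -> bool:
        return self._requires_grad

    @property
    def shape(self) -> torch.Size:
        return self._data.shape  # delegate to the backing data, as the PR's shape property does

t = CachedAttrTensorSketch(torch.zeros(4, 8, dtype=torch.uint8), dtype=torch.bfloat16)
assert t.dtype is torch.bfloat16 and t.shape == (4, 8)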

Confidence Score: 1/5

  • This PR has critical memory safety bugs that will cause memory leaks and potential crashes in production
  • Score reflects multiple critical P0 issues in quantizer.cpp: memory leaks from unreleased PyObject references (12+ instances), potential use-after-free bugs with temporary py::object lifetimes, and missing NULL checks for Python C API calls. These are serious memory safety violations that must be fixed
  • Pay critical attention to transformer_engine/pytorch/csrc/quantizer.cpp - contains multiple memory leaks and safety issues that must be resolved before merge

Important Files Changed

  • transformer_engine/pytorch/csrc/quantizer.cpp: Critical memory safety issues: memory leaks from PyTuple_New(0) not being decremented, potential use-after-free from temporary py::object in PyDict_SetItemString, missing NULL checks for Python C API calls, exception safety issues with manual memory management
  • transformer_engine/pytorch/quantized_tensor.py: Added property caching for dtype and requires_grad with lazy initialization fallback. Potential cache staleness if PyTorch modifies requires_grad internally, but implementation correctly syncs cache with parent tensor
  • transformer_engine/pytorch/csrc/extensions/pybind.cpp: Replaced individual null checks with std::call_once for thread-safe initialization. Minor concern: individual init functions no longer guarded if called directly outside init_extension()
  • transformer_engine/pytorch/cpp_extensions/gemm.py: Reordered device attribute checks to prioritize quantized tensor attributes. Optimizes quantized tensors but adds overhead for regular tensors (4 extra hasattr calls)

Sequence Diagram

sequenceDiagram
    participant User as User Code
    participant Linear as Linear Module
    participant Quantizer as Float8Quantizer (C++)
    participant PyAPI as Python C API
    participant Tensor as QuantizedTensor
    
    User->>Linear: forward(input, weight)
    Note over Linear: Cache requires_grad checks
    Linear->>Linear: inp_requires_grad = inp.requires_grad
    Linear->>Linear: weight_requires_grad = weight.requires_grad
    
    Linear->>Quantizer: create_tensor(shape, dtype)
    Note over Quantizer: Cache GEMM support check
    Quantizer->>Quantizer: is_non_tn_fp8_gemm_supported = nvte_is_non_tn_fp8_gemm_supported()
    
    Note over Quantizer: Direct C API (bypassing pybind11)
    Quantizer->>Quantizer: stride = stride_from_shape(shape)
    Quantizer->>PyAPI: PyDict_New(), PyTuple_New(0)
    Quantizer->>PyAPI: PyDict_SetItemString(kwargs, ...)
    Quantizer->>PyAPI: PyObject_Call(Float8TensorClass, args, kwargs)
    Quantizer->>PyAPI: Py_DECREF(kwargs), Py_DECREF(args)
    
    PyAPI-->>Tensor: Float8Tensor instance
    Note over Tensor: Cached properties initialized
    Tensor->>Tensor: _dtype = dtype
    Tensor->>Tensor: _requires_grad = requires_grad
    
    Tensor-->>User: Quantized tensor
    
    User->>Tensor: tensor.dtype
    Note over Tensor: Return cached _dtype (no PyObject lookup)
    Tensor-->>User: dtype
    
    User->>Tensor: tensor.shape
    Note over Tensor: Return _data.shape (cached property)
    Tensor-->>User: shape

Contributor

@greptile-apps greptile-apps bot left a comment

Additional Comments (3)

  1. transformer_engine/pytorch/csrc/util.cpp, line 18-20 (link)

    logic: Critical logical error: || should be &&. This condition will always be true since a value cannot simultaneously be both scaling modes, causing the function to always return nullopt for valid inputs.

  2. transformer_engine/pytorch/quantized_tensor.py, line 373-393 (link)

    style: commented-out code for requires_grad caching optimization - consider removing dead code entirely. Is this code planned to be implemented later or should it be removed?


  3. transformer_engine/pytorch/module/linear.py, line 484 (link)

    logic: Logical error: this condition should combine the flags with OR (or), not AND (and). The original logic was checking if ANY tensor requires gradients for FP8 handling, but this now only activates when ALL three require gradients, including bias which may be None.

    Should the FP8 condition check if any tensor requires gradients (OR logic) rather than all tensors (AND logic)?

10 files reviewed, 3 comments


Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
…ormerEngine into cpu_fp8_optimizations

Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
@vthumbe1503
Collaborator Author

/te-ci L1 pytorch

Contributor

@greptile-apps greptile-apps bot left a comment

Greptile Overview

Greptile Summary

This PR implements CPU-side performance optimizations for FP8 operations by caching frequently accessed attributes and reducing redundant function calls. The optimizations target expensive PyObject attribute lookups on custom tensor types and repeated C++ function calls.

Key Changes:

  • Caches requires_grad, dtype, shape, and is_cuda attribute accesses to avoid expensive PyObject lookups on custom tensors
  • Reorders attribute checks in get_tensor_device() to prioritize internal quantized tensor attributes
  • Makes num_devices static in nvte_is_non_tn_fp8_gemm_supported() to cache device count
  • Stores GEMM support check results in local variables to avoid redundant function calls

Critical Issues Found:

  • Variable redeclaration error in cublaslt_gemm.cu (line 224) will prevent compilation
  • Logic bug in linear.py (line 484) changes FP8 state management from OR logic to AND logic, breaking functionality when bias is None or doesn't require grad

Confidence Score: 0/5

  • This PR cannot be merged due to compilation error and critical logic bug
  • Two critical issues prevent merging: (1) C++ compilation will fail due to variable redeclaration at line 224 of cublaslt_gemm.cu, and (2) logic bug at line 484 of linear.py breaks FP8 state management by requiring all three tensors to have requires_grad=True instead of any one of them
  • Pay close attention to transformer_engine/common/gemm/cublaslt_gemm.cu (compilation error) and transformer_engine/pytorch/module/linear.py (logic bug)

Important Files Changed

File Analysis

  • transformer_engine/common/gemm/cublaslt_gemm.cu (1/5): Caches function call result to reduce overhead, but contains variable redeclaration error that will cause compilation failure
  • transformer_engine/common/transformer_engine.cpp (5/5): Makes num_devices static to avoid redundant calls to cuda::num_devices() - valid optimization
  • transformer_engine/pytorch/module/linear.py (0/5): Caches requires_grad checks for performance, but contains critical logic bug at line 484 that changes FP8 state management behavior

Sequence Diagram

sequenceDiagram
    participant User as User Code
    participant Linear as Linear Module
    participant Quantizer as Quantizer/QuantizedTensor
    participant GEMM as GEMM Operations
    participant CPP as C++ Extensions

    Note over Linear,CPP: Performance Optimization Flow
    
    User->>Linear: forward(input, weight, bias)
    
    Note over Linear: Cache requires_grad checks
    Linear->>Linear: inp_requires_grad = inp.requires_grad<br/>weight_requires_grad = weight.requires_grad<br/>bias_requires_grad = bias.requires_grad
    
    Linear->>Quantizer: Check if quantized tensor
    alt QuantizedTensor
        Note over Quantizer: Use cached dtype property
        Quantizer->>Quantizer: return self._dtype
        Note over Quantizer: Use cached shape/is_cuda
        Quantizer->>Quantizer: return self._data.shape
    else Regular Tensor
        Quantizer->>Linear: Standard attribute access
    end
    
    Linear->>CPP: get_tensor_device(tensor)
    Note over CPP: Reordered attribute checks
    CPP->>CPP: Check _rowwise_data first<br/>Check _columnwise_data<br/>Check device last
    CPP-->>Linear: device_index
    
    Linear->>GEMM: Configure GEMM parameters
    Note over GEMM: Cache nvte_is_non_tn_fp8_gemm_supported
    GEMM->>CPP: nvte_is_non_tn_fp8_gemm_supported()
    Note over CPP: Static num_devices cached
    CPP-->>GEMM: support_flag
    GEMM->>GEMM: Store in local variable
    
    GEMM->>GEMM: Execute optimized GEMM
    GEMM-->>Linear: output
    
    Note over Linear: FP8 State Management
    alt FP8 enabled and requires_grad check
        Linear->>Linear: Update FP8 tensors<br/>based on cached flags
    end
    
    Linear-->>User: output

@greptile-apps
Contributor

greptile-apps bot commented Jan 7, 2026

Additional Comments (2)

transformer_engine/common/gemm/cublaslt_gemm.cu
variable redeclared in same scope - already declared at line 132

    // int is_nvte_non_tn_fp8_gemm_supported already declared at line 132

transformer_engine/pytorch/module/linear.py
logic change from original requires_grad(inp, weight, bias) which returns True if ANY tensor requires grad. New code requires ALL THREE to be True, breaking FP8 state management when bias is None or doesn't require grad

            if ctx.fp8 and (inp_requires_grad or weight_requires_grad or bias_requires_grad):

Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
@vthumbe1503
Collaborator Author

/te-ci L1 pytorch

@vthumbe1503
Collaborator Author

/te-ci L1 pytorch

Contributor

@greptile-apps greptile-apps bot left a comment

8 files reviewed, 8 comments


ret.A_scale_inv = A.scale_inv.dptr;
ret.lda = is_A_transposed ? k : m;
if (!nvte_is_non_tn_fp8_gemm_supported() && !is_A_transposed) {
int is_nvte_non_tn_fp8_gemm_supported = nvte_is_non_tn_fp8_gemm_supported();

Variable is_nvte_non_tn_fp8_gemm_supported is redeclared in the same scope - it was already declared at line 224 for the B matrix configuration. This causes a compilation error or shadowing issue.

The variable should only be declared once at the beginning of the function scope and reused for both A and B matrix configurations.

Suggested change
int is_nvte_non_tn_fp8_gemm_supported = nvte_is_non_tn_fp8_gemm_supported();
// Move this declaration before line 125 (Configure A matrix)
int is_nvte_non_tn_fp8_gemm_supported = nvte_is_non_tn_fp8_gemm_supported();

Comment on lines 165 to 166
PyObject* result = PyObject_Call(reinterpret_cast<PyObject*>(Float8TensorStoragePythonClass),
PyTuple_New(0), kwargs);

Memory leak: PyTuple_New(0) creates a new Python tuple object with a reference count of 1, but this reference is never decremented. The tuple is passed directly to PyObject_Call which doesn't steal the reference, so it must be explicitly released.

This same issue occurs at lines 186, 398, 417, 697, 718, 1039, 1058, 1346, and 1367.

Suggested change
PyObject* result = PyObject_Call(reinterpret_cast<PyObject*>(Float8TensorStoragePythonClass),
PyTuple_New(0), kwargs);
PyObject* args = PyTuple_New(0);
PyObject* result = PyObject_Call(reinterpret_cast<PyObject*>(Float8TensorStoragePythonClass),
args, kwargs);
Py_DECREF(args);
Py_DECREF(kwargs);

PyDict_SetItemString(kwargs, "fp8_dtype", py::cast(this->dtype).ptr());
PyDict_SetItemString(kwargs, "data_transpose", transpose_py.ptr());
PyDict_SetItemString(kwargs, "quantizer", this->quantizer.ptr());
PyDict_SetItemString(kwargs, "device", py::cast(device).ptr());

Invalid parameter: The Float8TensorStorage.__new__ method does not accept a device parameter. Looking at the Python class definition in float8_tensor_storage.py, the __new__ signature only expects: data, fp8_scale_inv, fp8_dtype, data_transpose, and quantizer.

This will cause a runtime error when trying to create Float8TensorStorage instances. This issue also occurs at lines 184, 396, 415, and in similar patterns for MXFP8 and other quantizers.

Suggested change
PyDict_SetItemString(kwargs, "device", py::cast(device).ptr());
// Remove this line - device parameter is not accepted
// PyDict_SetItemString(kwargs, "device", py::cast(device).ptr());

PyTypeObject *NVFP4TensorPythonClass = nullptr;
PyTypeObject *NVFP4TensorStoragePythonClass = nullptr;
PyTypeObject *NVFP4QuantizerClass = nullptr;
bool is_extension_initialized = false;

Race condition: The is_extension_initialized flag is not thread-safe. If init_extension() is called concurrently from multiple threads, both threads could pass the check at line 105 before either sets the flag at line 110, causing the initialization code to run multiple times.

The original implementation using null-checks on the individual class pointers was safer because Python's module import system provides thread safety. Consider using std::once_flag and std::call_once for proper thread-safe initialization:

Suggested change
bool is_extension_initialized = false;
bool is_extension_initialized = false;
std::once_flag extension_init_flag;

And in init_extension():

void init_extension() {
  std::call_once(extension_init_flag, []() {
    init_float8_extension();
    init_mxfp8_extension();
    init_float8blockwise_extension();
    init_nvfp4_extensions();
  });
}

Comment on lines 964 to 966
static int num_devices = transformer_engine::cuda::num_devices();
static std::vector<int> cache(num_devices, -1);
static std::vector<std::once_flag> flags(num_devices);

Potential edge case: The num_devices is captured once as a static variable, so if the number of devices changes at runtime (e.g., with CUDA MPS or device enumeration changes), the cache and flags vectors could be undersized, leading to out-of-bounds access at line 968.

While dynamic device changes are rare, consider adding a bounds check:

Suggested change
static int num_devices = transformer_engine::cuda::num_devices();
static std::vector<int> cache(num_devices, -1);
static std::vector<std::once_flag> flags(num_devices);
static int num_devices = transformer_engine::cuda::num_devices();
static std::vector<int> cache(num_devices, -1);
static std::vector<std::once_flag> flags(num_devices);
int device_id = transformer_engine::cuda::current_device();
NVTE_CHECK(device_id < num_devices, "Device ID exceeds cached device count");

Contributor

@greptile-apps greptile-apps bot left a comment

4 files reviewed, 4 comments


ret.B_scale_inv = B.scale_inv.dptr;
ret.ldb = is_B_transposed ? n : k;
if (!nvte_is_non_tn_fp8_gemm_supported() && is_B_transposed) {
int is_nvte_non_tn_fp8_gemm_supported = nvte_is_non_tn_fp8_gemm_supported();

Variable is_nvte_non_tn_fp8_gemm_supported is redeclared in the same function scope (already declared at line 132). In C++, this is a compilation error. The variable should be declared once at the function scope level before both the A matrix and B matrix configuration blocks.

Suggested change
int is_nvte_non_tn_fp8_gemm_supported = nvte_is_non_tn_fp8_gemm_supported();
// Move declaration to line 104 (after GemmParam ret;) to avoid redeclaration

Comment on lines 146 to 150
py::object scale_inv_py = py::cast(scale_inv);
// Initialize scale-inverse tensor
if (!scale_inv) {
scale_inv = at::reciprocal(scale);
}

Critical bug: scale_inv_py is created on line 146 by dereferencing scale_inv (which is std::optional<at::Tensor>), but scale_inv is only initialized on lines 148-150 if it's empty. When scale_inv is not provided to this function, line 146 will dereference an empty optional, causing undefined behavior or a crash.

The fix is to move line 146 after the initialization block:

Suggested change
py::object scale_inv_py = py::cast(scale_inv);
// Initialize scale-inverse tensor
if (!scale_inv) {
scale_inv = at::reciprocal(scale);
}
py::object transpose_py = with_transpose ? py::cast(*transpose) : py::none();
// Initialize scale-inverse tensor
if (!scale_inv) {
scale_inv = at::reciprocal(scale);
}
py::object scale_inv_py = py::cast(*scale_inv);

PyDict_SetItemString(kwargs, "device", py::cast(device).ptr());

PyObject* result = PyObject_Call(reinterpret_cast<PyObject*>(Float8TensorStoragePythonClass),
PyTuple_New(0), kwargs);

Memory leak: PyTuple_New(0) creates a new tuple with reference count 1, but PyObject_Call does not steal the reference. The tuple must be stored and explicitly DECREF'd after the call. This pattern is repeated 12 times throughout this file (lines 166, 186, 398, 417, 697, 718, 1039, 1058, 1346, 1367).

Suggested change
PyTuple_New(0), kwargs);
PyObject* args = PyTuple_New(0);
PyObject* result = PyObject_Call(reinterpret_cast<PyObject*>(Float8TensorStoragePythonClass),
args, kwargs);
Py_DECREF(args);

vthumbe1503 and others added 3 commits January 11, 2026 19:12
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
…ormerEngine into cpu_fp8_optimizations

Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
Signed-off-by: vthumbe1503 <vthumbe@nvidia.com>
Contributor

@greptile-apps greptile-apps bot left a comment

13 files reviewed, 13 comments


ret.B_scale_inv = B.scale_inv.dptr;
ret.ldb = is_B_transposed ? n : k;
if (!nvte_is_non_tn_fp8_gemm_supported() && is_B_transposed) {
int is_nvte_non_tn_fp8_gemm_supported = nvte_is_non_tn_fp8_gemm_supported();

[P0] Variable is_nvte_non_tn_fp8_gemm_supported is already declared at line 132. Declaring it again here causes a compilation error. Remove this declaration and reuse the variable from line 132.

Suggested change
int is_nvte_non_tn_fp8_gemm_supported = nvte_is_non_tn_fp8_gemm_supported();
// Reuse is_nvte_non_tn_fp8_gemm_supported from line 132

}
py::object transpose_py = with_transpose ? py::cast(*transpose) : py::none();

py::object transpose_py = with_transpose ? py::cast(*transpose) : py::none();

[P0] Use-after-free bug: scale_inv_py is created from scale_inv before scale_inv is initialized. When scale_inv parameter is empty (lines 148-150), this dereferences an empty std::optional, causing undefined behavior. Move this line after the scale_inv initialization.

Suggested change
py::object transpose_py = with_transpose ? py::cast(*transpose) : py::none();
// Move scale_inv_py creation after initialization (line 150)

PyDict_SetItemString(kwargs, "quantizer", this->quantizer.ptr());
PyDict_SetItemString(kwargs, "device", py::cast(device).ptr());

PyObject* result = PyObject_Call(reinterpret_cast<PyObject*>(Float8TensorStoragePythonClass),

[P0] Memory leak: PyTuple_New(0) creates a new Python object with refcount 1, but it's never decremented. This leaks memory on every tensor creation. Add Py_DECREF for the tuple:

Suggested change
PyObject* result = PyObject_Call(reinterpret_cast<PyObject*>(Float8TensorStoragePythonClass),
PyObject* args = PyTuple_New(0);
PyObject* result = PyObject_Call(reinterpret_cast<PyObject*>(Float8TensorStoragePythonClass), args, kwargs);
Py_DECREF(args);

PyDict_SetItemString(kwargs, "is_2D_scaled", py::cast(block_scaling_dim == 2).ptr());
PyDict_SetItemString(kwargs, "data_format", py::cast(data_format).ptr());

PyObject* result = PyObject_Call(reinterpret_cast<PyObject*>(Float8BlockwiseQTensorPythonClass),

[P0] Memory leak: PyTuple_New(0) is not decremented (same pattern as line 166).

PyDict_SetItemString(kwargs, "fp8_dtype", py::cast(this->dtype).ptr());
PyDict_SetItemString(kwargs, "quantizer", this->quantizer.ptr());

PyObject* result = PyObject_Call(reinterpret_cast<PyObject*>(MXFP8TensorStoragePythonClass),

[P0] Memory leak: PyTuple_New(0) is not decremented (same pattern as line 166).

PyDict_SetItemString(kwargs, "fp4_dtype", py::cast(this->dtype).ptr());
PyDict_SetItemString(kwargs, "quantizer", this->quantizer.ptr());

PyObject* result = PyObject_Call(reinterpret_cast<PyObject*>(NVFP4TensorStoragePythonClass),

[P0] Memory leak: PyTuple_New(0) is not decremented (same pattern as line 166).

Contributor

@greptile-apps greptile-apps bot left a comment

12 files reviewed, 1 comment


return self._data.shape
if self._transpose is not None:
transpose_shape = self._transpose.shape
return tuple(transpose_shape[1:]) + (transpose_shape[0],)

style: Returns tuple instead of torch.Size when computed from transpose, while self._data.shape returns torch.Size. Consider wrapping with torch.Size() for type consistency:

Suggested change
return tuple(transpose_shape[1:]) + (transpose_shape[0],)
return torch.Size(tuple(transpose_shape[1:]) + (transpose_shape[0],))

Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
@greptile-apps
Contributor

greptile-apps bot commented Jan 19, 2026

Greptile's behavior is changing!

From now on, if a review finishes with no comments, we will not post an additional "statistics" comment to confirm that our review found nothing to comment on. However, you can confirm that we reviewed your changes in the status check section.

This feature can be toggled off in your Code Review Settings by deselecting "Create a status check for each PR".

vthumbe1503 and others added 2 commits January 19, 2026 22:25
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
@vthumbe1503
Collaborator Author

/te-ci L1 pytorch

3 similar comments
@vthumbe1503
Collaborator Author

/te-ci L1 pytorch

@vthumbe1503
Collaborator Author

/te-ci L1 pytorch

@vthumbe1503
Collaborator Author

/te-ci L1 pytorch

Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
@vthumbe1503
Collaborator Author

/te-ci L1 pytorch

@vthumbe1503
Collaborator Author

/te-ci pytorch

Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
@vthumbe1503
Collaborator Author

/te-ci pytorch

@vthumbe1503
Collaborator Author

/te-ci pytorch

… at::empty

Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
@vthumbe1503
Collaborator Author

/te-ci pytorch

@vthumbe1503
Collaborator Author

/te-ci L1 pytorch

@vthumbe1503
Collaborator Author

/te-ci pytorch

@vthumbe1503
Collaborator Author

/te-ci L1 pytorch

Return the high precision data type of the tensor
Attribute access of custom tensors goes through an
expensive Pyobject lookup. Since dtype for a tensor is never
change after creation, we cache it in a member variable and return
Collaborator Author

Fix this comment: dtype can get changed when custom_tensor.data is modified using set.

Contributor

@greptile-apps greptile-apps bot left a comment

15 files reviewed, 15 comments


Comment on lines +363 to +365
# Calculate stride from shape if not provided. When creating this object from
# C++ code, we provide the stride computed from shape in C++ to avoid the
# PyobjectVectorCall overhead of calling _stride_from_shape from C++ to Python.

cached _dtype and _requires_grad are only initialized here in __new__, but QuantizedTensor can be created through other paths (unpickling, tensor ops, FSDP state loading) that may bypass this initialization. The property getters include fallback logic with hasattr() checks, but this adds overhead to the optimization.

ensure all creation paths properly initialize these cached attributes, or document that the fallback path is intentional
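A getter with the kind of hasattr() fallback mentioned above could look roughly like this (a sketch only; the class name and the re-derivation from the wrapped data are illustrative, not the PR's actual fallback logic):

import torch

class LazyDtypeSketch:
    """Toy example: the cached value is rebuilt on first access if a creation
    path (e.g. unpickling) skipped normal construction."""

    def __init__(self, data: torch.Tensor, dtype: torch.dtype):
        self._data = data
        self._dtype = dtype

    @property
    def dtype(self) -> torch.dtype:
        if hasattr(self, "_dtype"):      # fast path: cache populated at construction
            return self._dtype
        self._dtype = self._data.dtype   # fallback: derive once and cache
        return self._dtype

# An instance materialized without running __init__ (as unpickling can do)
# hits the fallback branch exactly once:
obj = LazyDtypeSketch.__new__(LazyDtypeSketch)
obj._data = torch.ones(2, dtype=torch.float16)
assert obj.dtype is torch.float16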

Comment on lines +395 to +396
@dtype.setter
def dtype(self, value: torch.dtype) -> None:

the dtype setter allows changing dtype after creation with only a warning. For quantized tensors, changing _dtype without re-quantizing the underlying FP8/MXFP8/NVFP4 data creates a dangerous mismatch between the cached dtype and actual data representation.

consider raising an error instead of warning, or trigger re-quantization
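If the stricter behaviour were preferred, the setter could simply refuse the change. Continuing the toy class from the previous sketch (hypothetical; the PR's actual setter only warns):

    @dtype.setter
    def dtype(self, value: torch.dtype) -> None:
        if value != self._dtype:
            raise RuntimeError(
                "Changing the nominal dtype of a quantized tensor without "
                "re-quantizing its data is not supported"
            )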

Comment on lines +192 to +208
PyObject* kwargs = PyDict_New();
PyObject* args = PyTuple_New(0);
PyDict_SetItemString(kwargs, "data", data_py.ptr());
PyDict_SetItemString(kwargs, "fp8_scale_inv", scale_inv_py.ptr());
PyDict_SetItemString(kwargs, "fp8_dtype", py::cast(this->dtype).ptr());
PyDict_SetItemString(kwargs, "data_transpose", transpose_py.ptr());
PyDict_SetItemString(kwargs, "quantizer", this->quantizer.ptr());

PyObject* result =
PyObject_Call(reinterpret_cast<PyObject*>(Float8TensorStoragePythonClass), args, kwargs);
if (result == nullptr) {
PyErr_Print();
}
Py_DECREF(kwargs);
Py_DECREF(args);
NVTE_CHECK(result != nullptr, "Failed to create Float8TensorStorage instance");
out_py = py::reinterpret_steal<py::object>(result);

critical memory leaks and exception safety issues:

  1. PyTuple_New(0) creates a tuple with refcount=1, but Py_DECREF only happens after the call. If PyObject_Call fails and returns NULL, the program terminates via NVTE_CHECK with the refs properly released, but this is a fatal path
  2. No NULL checks after PyDict_New() or PyTuple_New() - if allocation fails, subsequent operations will crash
  3. Exception safety: if any py::cast() throws between allocation and cleanup, kwargs and args leak
  4. py::cast(this->dtype).ptr() creates a temporary object that's destroyed after the statement, but PyDict_SetItemString should increment the refcount. This pattern is fragile

use pybind11's py::dict and py::tuple for RAII-based memory management:

py::dict kwargs;
kwargs["data"] = data_py;
kwargs["fp8_scale_inv"] = scale_inv_py;
kwargs["fp8_dtype"] = py::cast(this->dtype);
kwargs["data_transpose"] = transpose_py;
kwargs["quantizer"] = this->quantizer;

py::object result = py::reinterpret_steal<py::object>(
    PyObject_Call(reinterpret_cast<PyObject*>(Float8TensorStoragePythonClass), 
                  py::tuple().ptr(), kwargs.ptr()));
NVTE_CHECK(result.ptr() != nullptr, "Failed to create Float8TensorStorage instance");

this same pattern repeats throughout the file at lines 214-234, 382-404, and others

Comment on lines +62 to +73
at::Tensor empty_cuda(std::vector<int64_t> sizes_vector, const at::TensorOptions& opts) {
  // at::empty under the hood calls at::native::empty_cuda, however it introduces a device guard
  // which is called again in at::native::empty_cuda anyway. So directly calling that to avoid the
  // double device guard overhead.
  auto sizes = at::ArrayRef<int64_t>(sizes_vector);
  auto device_opt = opts.device_opt();
  c10::optional<at::ScalarType> dtype_opt = c10::nullopt;
  if (opts.dtype_opt().has_value()) {
    dtype_opt = opts.dtype_opt()->toScalarType();
  }
  return at::native::empty_cuda(sizes, dtype_opt, opts.layout_opt(), device_opt,
                                opts.pinned_memory_opt(), opts.memory_format_opt());

the empty_cuda wrapper bypasses the standard at::empty dispatch by calling at::native::empty_cuda directly to avoid double device guard overhead. However, this skips the dispatch mechanism that handles tensor subclasses, hooks, and custom allocators.

verify this doesn't break compatibility with PyTorch features
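If that compatibility risk matters, one defensive option is to gate the fast path and defer to the regular dispatcher otherwise. A hypothetical sketch (the wrapper name and guard conditions are illustrative and not part of this PR):

#include <ATen/ATen.h>

#include <utility>
#include <vector>

// Use the PR's empty_cuda() fast path only for plain strided CUDA allocations;
// anything else goes through at::empty so tensor subclasses, hooks, and custom
// allocators keep working.
at::Tensor empty_cuda_or_fallback(std::vector<int64_t> sizes, const at::TensorOptions& opts) {
  const bool plain_cuda = opts.device().is_cuda() && opts.layout() == at::kStrided;
  if (!plain_cuda) {
    return at::empty(sizes, opts);            // standard dispatch path
  }
  return empty_cuda(std::move(sizes), opts);  // fast path from the diff above
}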

Comment on lines 79 to 87
if hasattr(tensor, "_rowwise_data") and tensor._rowwise_data is not None:
    return tensor._rowwise_data.device.index
if hasattr(tensor, "_columnwise_data") and tensor._columnwise_data is not None:
    return tensor._columnwise_data.device.index
if hasattr(tensor, "_data") and tensor._data is not None:
    return tensor._data.device.index
if hasattr(tensor, "_transpose") and tensor._transpose is not None:
    return tensor._transpose.device.index
if hasattr(tensor, "device"):

reordering optimizes for quantized tensors but creates performance regression for regular tensors. Before, regular torch.Tensor objects checked device immediately (1 attribute check). Now they check 4 non-existent quantized attributes via hasattr() first.

if get_tensor_device() is frequently called with regular tensors, consider adding an early type check:

if type(tensor).__name__ == 'Tensor':
    return tensor.device.index if hasattr(tensor, 'device') else torch.cuda.current_device()

Comment on lines +42 to +57

static std::unordered_map<std::string, void *> symbol_cache;
static std::mutex cache_mutex;
FuncT *func;

{
  std::lock_guard<std::mutex> lock(cache_mutex);
  auto it = symbol_cache.find(symbol);
  if (it == symbol_cache.end()) {
    void *ptr = get_symbol(symbol);
    symbol_cache[symbol] = ptr;
    func = reinterpret_cast<FuncT *>(ptr);
  } else {
    func = reinterpret_cast<FuncT *>(it->second);
  }
}

excellent optimization - caching CUDA driver symbols with proper mutex synchronization avoids repeated get_symbol() calls. The mutex correctly protects both reads and writes to symbol_cache from concurrent threads


Comment on lines +123 to +126
int is_nvte_non_tn_fp8_gemm_supported = 0; // needed only for per tensor scaling
if (is_tensor_scaling(A.scaling_mode) || is_tensor_scaling(B.scaling_mode)) {
is_nvte_non_tn_fp8_gemm_supported = nvte_is_non_tn_fp8_gemm_supported();
}

caching nvte_is_non_tn_fp8_gemm_supported() result avoids redundant calls throughout the GEMM configuration for both A and B matrices. Clean optimization with proper scoping


Comment on lines +132 to +134
inp_requires_grad = inp.requires_grad
weight_requires_grad = weight.requires_grad
bias_requires_grad = bias.requires_grad if bias is not None else False

caching requires_grad checks early avoids redundant attribute lookups throughout the forward pass (used at lines 147, 260, 385, 454, 480, 482, 485). The OR logic at line 485 is semantically equivalent to the original requires_grad(inp, weight, bias) helper


Comment on lines 913 to +930

@property
def shape(self):
    """Return the shape of the tensor. Define this to avoid expensive PyObject lookups."""
    if self._data is not None:
        return self._data.shape
    if self._transpose is not None:
        transpose_shape = self._transpose.shape
        return torch.Size(tuple(transpose_shape[1:]) + (transpose_shape[0],))
    raise RuntimeError("Both data and transpose are None")

@property
def is_cuda(self):
    """Return whether the tensor is on a CUDA device."""
    if self._data is not None:
        return self._data.is_cuda
    if self._transpose is not None:
        return self._transpose.is_cuda

the cached shape and is_cuda properties include proper error handling with RuntimeError when both _data and _transpose are None. This prevents AttributeError in edge cases and provides clear error messages

Comment on lines +932 to 933
dtype = inp.dtype
# All checks after this have already been performed once, thus skip

caching inp.dtype early avoids a second attribute lookup at line 933 (originally line 936). Clean micro-optimization with no functional change


@greptile-apps
Contributor

greptile-apps bot commented Jan 25, 2026

Additional Comments (1)

transformer_engine/pytorch/tensor/mxfp8_tensor.py
the cached shape and is_cuda properties include proper error handling with RuntimeError when both _rowwise_data and _columnwise_data are None, matching the pattern in Float8Tensor

