
Conversation

@Edwardf0t1 (Contributor) commented Jan 9, 2026

What does this PR do?

Type of change: New feature

Overview:

The primary goal of this PR is to allow the model optimizer to use image-text pair data during the calibration phase of quantization. This is likely to improve the accuracy of quantized VLMs such as Nemotron VL, particularly on visual understanding tasks, compared to text-only calibration data.

  • New Feature: Adds support for VLM calibration specifically using image-text data.
  • Dataset Integration: Introduces support for sampling from the Nemotron-VLM-Dataset-v2.
  • Refactoring: Adds a separate utility module for VLM datasets to keep the main Hugging Face PTQ script (hf_ptq.py) clean.
  • Simplifies the logic for handling multimodal inputs.
  • Addresses issues encountered when calibrating the Nemotron-Nano-VL-12B-V2 model with image data.
  • Documentation: Updated the README to include instructions and examples for VLM calibration.

This PR complements #347, and we will consolidate the llm_ptq and vlm_ptq examples in follow-up PRs.

Usage

python3 hf_ptq.py \
  --pyt_ckpt_path /home/scratch.omniml_data_2/models/Nemotron-Nano-VL-12B-V2 \
  --qformat nvfp4 \
  --export_path /home/omniml_data_3/zhiyuc/checkpoints/Nemotron-Nano-VL-12B-V2-NVFP4-doccalib \
  --trust_remote_code \
  --kv_cache_qformat none \
  --calib_with_images \
  --vlm_dataset nemotron_vlm_dataset_v2 \
  --vlm_subsets sparsetables,plotqa_cot \
  --calib_size 512

Testing

Before your PR is "Ready for review"

  • Make sure you read and follow Contributor guidelines and your commits are signed.
  • Is this change backward compatible?: Yes
  • Did you write any new necessary tests?: Yes/No
  • Did you add or update any necessary documentation?: Yes
  • Did you update Changelog?: Not yet

Additional Information

Summary by CodeRabbit

  • New Features

    • Added Vision-Language Model (VLM) calibration support with image-text pair data, specifically for Nemotron VL models.
    • Added new --calib_with_images CLI flag to enable image-based calibration workflows.
    • Integrated Nemotron VLM dataset v2 for streaming multimodal calibration data.
  • Documentation

    • Added VLM calibration guidance in the PTQ README with usage examples and dataset information.


@copy-pr-bot bot commented Jan 9, 2026

Auto-sync is disabled for draft pull requests in this repository. Workflows must be run manually.

Contributors can view more details about this message here.

@codecov bot commented Jan 9, 2026

Codecov Report

❌ Patch coverage is 10.12270% with 293 lines in your changes missing coverage. Please review.
✅ Project coverage is 73.08%. Comparing base (3036a9e) to head (b313f93).

Files with missing lines (patch %, lines missing):
  • modelopt/torch/utils/vlm_dataset_utils.py: 8.76%, 177 missing ⚠️
  • modelopt/torch/utils/nemotron_vlm_dataset_utils.py: 12.12%, 116 missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main     #755      +/-   ##
==========================================
- Coverage   74.13%   73.08%   -1.05%     
==========================================
  Files         192      193       +1     
  Lines       19263    19583     +320     
==========================================
+ Hits        14280    14312      +32     
- Misses       4983     5271     +288     

☔ View full report in Codecov by Sentry.

@Edwardf0t1 Edwardf0t1 self-assigned this Jan 14, 2026
@Edwardf0t1 Edwardf0t1 marked this pull request as ready for review January 14, 2026 01:16
@Edwardf0t1 Edwardf0t1 requested review from a team as code owners January 14, 2026 01:16
@shengliangxu (Contributor) commented Jan 14, 2026

So, we only support image quantization for just nemotron-vl? If yes, why?

# limitations under the License.

"""Utility functions for getting samples and forward loop function for different vlm datasets."""
"""Utility functions for getting samples and dataloader for different VLM calibration datasets.
Collaborator:

@ajrasane could you review this change?

@cjluo-nv (Collaborator) commented:

@Edwardf0t1 do you have experiments evaluating the accuracy impact of using the new dataset?

@Edwardf0t1 (Contributor, Author) commented:

So, we only support image quantization for just nemotron-vl? If yes, why?

At this time, only Nemotron VL has been tested. We can extend the logic to support other VLMs later. Note that different VLMs may have different forward functions—e.g., the way the vision encoder interacts with the language decoder can vary across models.

Do you have a preferred VL model you’d like us to support next? For instance, Qwen3-VL?
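
To make the model-specific forward concern concrete, a calibration loop can adapt to whatever arguments a given VLM's forward() accepts. The sketch below is illustrative only, not the helper added in this PR; the dataloader batch keys and the forward-loop signature expected by the quantizer are assumptions.

import inspect

import torch


def make_vlm_forward_loop(model, dataloader):
    """Hypothetical sketch: build a calibration forward loop that only passes
    the keyword arguments the model's forward() actually declares."""
    accepted = set(inspect.signature(model.forward).parameters)

    def forward_loop(_model=None):  # the signature expected by the quantizer is an assumption
        model.eval()
        with torch.no_grad():
            for batch in dataloader:
                # Keep only tensors the forward signature knows about,
                # e.g. input_ids, attention_mask, pixel_values.
                kwargs = {k: v for k, v in batch.items() if k in accepted}
                model(**kwargs)

    return forward_loop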

@Edwardf0t1 (Contributor, Author) commented:

@Edwardf0t1 do you have experiments evaluating the accuracy impact of using the new dataset?

Tested on two benchmarks, DocVQA and InfoVQA, for Nemotron Nano VL v2 with the vLLM backend (scores listed as DocVQA, InfoVQA):

  • BF16 baseline: 94.2184, 79.1404
  • NVFP4, text-only calibration: 93.9472, 77.7221
  • NVFP4, image-text calibration: 94.0854, 77.9598

Image-text calibration is only marginally better in these cases, but the calibration flow in this PR should be ready. Possible follow-up experiments:

  1. Calibrate with different subsets of Nemotron-VLM-Dataset-v2, or with another image-text dataset.
  2. Check more evaluation metrics.
  3. Run benchmarks on other VLMs such as Nemotron Parse and Qwen3-VL.

# prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
# inputs = processor(text=[prompt], images=[pil_image], ...)

def _collate_fn(examples: list[dict[str, Any]]) -> dict[str, torch.Tensor] | dict[str, Any]:
Collaborator:

why do we need to introduce these while the original one does not?

Contributor (Author):

Previously we didn't use image-text data for calibration, and standard DataLoader collation doesn't work for VLMs, for a few reasons (a rough sketch of the required collation follows):

  • The dataset has inconsistent image formats.
  • The conversational format must be converted into the model's input format.
  • The processor must process images and text together so that they align properly.
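
The sketch below is illustrative only (not this PR's _collate_fn); the processor call mirrors the commented lines in the diff hunk above, while the "messages"/"image" keys, padding, and device handling are assumptions.

from typing import Any

import torch


def collate_image_text(examples: list[dict[str, Any]], processor, device) -> dict[str, torch.Tensor]:
    """Hypothetical sketch: turn {"messages", "image"} samples into model inputs."""
    prompts, images = [], []
    for ex in examples:
        # Convert the conversational format into a single prompt string.
        prompt = processor.tokenizer.apply_chat_template(
            ex["messages"], tokenize=False, add_generation_prompt=True
        )
        prompts.append(prompt)
        images.append(ex["image"])  # assumed to already be a PIL image

    # The processor tokenizes text and preprocesses images together, so image
    # placeholder tokens stay aligned with the returned pixel_values.
    inputs = processor(text=prompts, images=images, return_tensors="pt", padding=True)
    return {k: v.to(device) for k, v in inputs.items()}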

Contributor:

Should we create a class for this collate function?

class VLMCollator:
    def __init__(self, processor, dataset_name, require_image, max_length, device):
        self.processor = processor
        self.repo_id = (
            SUPPORTED_VLM_DATASET_CONFIG[dataset_name]["config"]["path"]
            if dataset_name == "nemotron_vlm_dataset_v2"
            else None
        )
        self.image_root = getattr(processor, "_modelopt_vlm_image_root", None)
        self.require_image = require_image
        self.max_length = max_length
        self.device = device

    def __call__(self, examples):
        # ... the collate logic

This would make it more readable and easier to test.

Contributor (Author):

I don't think a class is needed here. This _collate_fn is already tightly scoped to get_vlm_dataset_dataloader() and only depends on a small set of captured variables. Making it a class would add boilerplate without real benefit unless you need stateful caching, metrics, or config reuse across multiple dataloaders.

@jingyu-ml (Contributor) left a comment:

LGTM. I only reviewed the dataset processing part, which behaves as expected, loading the dataset on demand rather than downloading the entire dataset.
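
For reference, on-demand loading with Hugging Face datasets generally follows the pattern below. This is only a generic sketch: the repo id and subset name are placeholders, and the PR's actual loader streams tar shards plus JSONL rather than relying on a ready-made dataset config.

from itertools import islice

from datasets import load_dataset

# Stream the dataset instead of downloading it in full; only the samples
# actually consumed for calibration are fetched over the network.
ds = load_dataset(
    "nvidia/Nemotron-VLM-Dataset-v2",  # placeholder repo id
    name="plotqa_cot",                 # placeholder subset name
    split="train",
    streaming=True,
)

calib_samples = list(islice(ds, 512))  # e.g. --calib_size 512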


@coderabbitai bot commented Jan 24, 2026

📝 Walkthrough

This pull request introduces Vision-Language Model (VLM) calibration support for post-training quantization. It adds new dataset utilities for streaming Nemotron VLM data, implements image-text pair calibration loops, extends the quantization pipeline to handle multimodal models, and includes documentation and helper functions for Nemotron VL model processing.

Changes

  • Documentation & VLM Integration (examples/llm_ptq/README.md, examples/llm_ptq/hf_ptq.py): Added a VLM calibration documentation section with dataset reference. Extended the quantization pipeline with VL model detection, an image-based calibration path via the --calib_with_images flag, language model extraction from VL models, and tokenizer/padding state management across quantization/export routines.
  • Calibration Loop Helpers (examples/llm_ptq/example_utils.py, examples/llm_ptq/nemotron_vl_calib.py): Introduced create_vlm_calibration_loop() to dynamically adapt calibration loops based on model forward signatures. Added safe_nemotron_vl_forward() for safe multimodal forward passes that extract image/text data, align vision embeddings with token positions, and trigger quantizer calibration without wrapper overhead.
  • VLM Dataset Infrastructure (modelopt/torch/utils/nemotron_vlm_dataset_utils.py, modelopt/torch/utils/vlm_dataset_utils.py): Created a new module for Nemotron VLM dataset streaming with tar-plus-JSONL handling, cached file listing, and image extraction from messages. Extended existing VLM dataset utilities with HuggingFace streaming dataset support, generic processor collation, image loading from local/remote sources, and a new nemotron_vlm_dataset_v2 configuration with subset handling.
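
To illustrate the vision-embedding alignment described above, here is a heavily simplified sketch of the idea behind a helper like safe_nemotron_vl_forward(). It is not the PR's code: attribute names such as extract_feature, language_model, and img_context_token_id are assumptions based on typical Nemotron/InternVL-style models.

import torch


@torch.no_grad()
def vlm_calib_forward(model, input_ids, attention_mask, pixel_values, img_context_token_id):
    """Hypothetical sketch: one multimodal forward pass for calibration."""
    # 1) Encode images and project them into the LM embedding space.
    vision_embeds = model.extract_feature(pixel_values)  # assumed helper on the VL model

    # 2) Embed the text tokens.
    inputs_embeds = model.language_model.get_input_embeddings()(input_ids)

    # 3) Replace image-placeholder token embeddings with vision embeddings
    #    (the number of placeholder tokens must match the number of vision embeddings).
    mask = input_ids == img_context_token_id
    inputs_embeds[mask] = vision_embeds.reshape(-1, vision_embeds.shape[-1]).to(inputs_embeds.dtype)

    # 4) Forward through the language model so quantizers observe multimodal activations.
    return model.language_model(inputs_embeds=inputs_embeds, attention_mask=attention_mask)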

Sequence Diagram(s)

sequenceDiagram
    participant User
    participant hf_ptq
    participant ModelLoader
    participant VLMProcessor
    participant DataLoader
    participant CalibLoop
    participant Quantizer

    User->>hf_ptq: Execute with --calib_with_images
    hf_ptq->>ModelLoader: load_model() with calib_with_images=True
    ModelLoader->>ModelLoader: Detect Nemotron VL model
    ModelLoader->>VLMProcessor: Create AutoProcessor
    VLMProcessor->>VLMProcessor: Configure padding tokens & side
    ModelLoader->>ModelLoader: extract_and_prepare_language_model_from_vl()
    ModelLoader->>hf_ptq: Return LM + default_pad_token
    
    hf_ptq->>DataLoader: Load nemotron_vlm_dataset_v2
    DataLoader->>DataLoader: Stream tar shards + JSONL
    DataLoader->>DataLoader: Match images to messages
    DataLoader->>hf_ptq: Yield {id, messages, image}
    
    hf_ptq->>CalibLoop: create_vlm_calibration_loop(model, dataloader)
    CalibLoop->>CalibLoop: Inspect model.forward signature
    
    loop Per batch
        CalibLoop->>CalibLoop: Extract pixel_values, input_ids, attention_mask
        CalibLoop->>CalibLoop: safe_nemotron_vl_forward()
        CalibLoop->>CalibLoop: Align vision embeddings with img_context_token_id
        CalibLoop->>CalibLoop: Run LM forward (no grad, eval mode)
    end
    
    hf_ptq->>Quantizer: quantize_main() with calibrated stats
    Quantizer->>hf_ptq: Export quantized LM
    hf_ptq->>hf_ptq: Restore tokenizer.pad_token

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~65 minutes

🚥 Pre-merge checks | ✅ 2 | ❌ 1
❌ Failed checks (1 warning)
  • Docstring Coverage ⚠️ Warning: Docstring coverage is 56.67%, which is insufficient; the required threshold is 80.00%. Resolution: write docstrings for the functions missing them to satisfy the coverage threshold.
✅ Passed checks (2 passed)
  • Description Check ✅ Passed: Check skipped; CodeRabbit's high-level summary is enabled.
  • Title Check ✅ Passed: The title clearly and specifically describes the main feature addition: enabling vision-language model calibration using image-text pair data.
