vjawa_move_preprocessing_to_device

#8
by VibhuJawa - opened

Move YoloXWrapper.preprocess to the model device so that preprocessing runs on GPU and returns tensors already placed on self.device.

Previously, preprocess returned CPU tensors and every call incurred a host→device transfer during batching/inference. This was dominating latency when preprocessing many images per batch.

Change Details

  • Always move the image tensor to self.device inside YoloXWrapper.preprocess.
  • Keeps the public API unchanged (preprocess still accepts numpy arrays or tensors) while ensuring downstream code works entirely on the model device; see the sketch below.
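
For illustration, a minimal sketch of the idea (not the actual diff: the real YoloXWrapper.preprocess also resizes/normalizes, and the constructor here is a hypothetical stand-in for however the wrapper sets self.device):

import numpy as np
import torch


class YoloXWrapper(torch.nn.Module):
    def __init__(self, device: str = "cuda:0"):
        super().__init__()
        # Hypothetical: the real wrapper presumably sets self.device elsewhere.
        self.device = torch.device(device)

    def preprocess(self, image) -> torch.Tensor:
        # Public API unchanged: accept numpy arrays or tensors.
        tensor = torch.from_numpy(image) if isinstance(image, np.ndarray) else image
        # The change: always return the tensor already on self.device, so
        # batching/inference never pays a per-image host->device copy.
        return tensor.to(self.device)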

Benchmarks

Script used: preprocess_benchmark.py (attached in PR), running on cuda:0 with warmup=3, iters=10.

Before (main, CPU-based preprocessing):

Benchmarking preprocess on device: cuda:0, warmup=3, iters=10
  BS |  mean (ms) |  p50 (ms) |  p95 (ms) |   imgs/s
----------------------------------------------------
   1 |       2.16 |      2.04 |      3.03 |   462.09
   2 |       3.87 |      3.79 |      5.15 |   516.64
   4 |      10.11 |      9.45 |     13.03 |   395.61
   8 |      19.86 |     18.86 |     25.30 |   402.85
  16 |      37.85 |     35.20 |     49.68 |   422.72
  32 |      72.42 |     67.39 |     96.91 |   441.89

After (this PR, preprocess on GPU / self.device):

Benchmarking preprocess on device: cuda:0, warmup=3, iters=10
  BS |  mean (ms) |  p50 (ms) |  p95 (ms) |   imgs/s
----------------------------------------------------
   1 |       0.29 |      0.26 |      0.46 |  3436.61
   2 |       0.89 |      0.47 |      2.73 |  2258.99
   4 |       0.92 |      0.86 |      1.12 |  4370.63
   8 |       1.87 |      1.79 |      2.15 |  4286.73
  16 |       5.16 |      4.93 |      6.59 |  3101.92
  32 |       9.85 |      9.71 |     10.54 |  3248.43

Rough improvement (from the mean columns above):

  • Latency: ~4–11× lower (mean ms, depending on batch size)
  • Throughput: ~4–11× higher (imgs/s)
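
As a sanity check, the ratios can be recomputed from the mean (ms) columns of the two tables:

# Per-batch-size speedup, using the mean-latency columns above.
before_ms = {1: 2.16, 2: 3.87, 4: 10.11, 8: 19.86, 16: 37.85, 32: 72.42}
after_ms = {1: 0.29, 2: 0.89, 4: 0.92, 8: 1.87, 16: 5.16, 32: 9.85}
for bs in before_ms:
    print(f"BS {bs:>2}: {before_ms[bs] / after_ms[bs]:.1f}x faster")
# Ranges from ~4.3x (BS 2) to ~11.0x (BS 4).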

Script here:

(hf_nvingest_models) vjawa@dgx-a100-01:/raid/vjawa/hf_models/nemotron-table-structure-v1$ cat preprocess_benchmark.py
#!/usr/bin/env python3
"""Benchmark the preprocessing step for different batch sizes."""

import argparse
import time
from typing import Iterable, List

import numpy as np
import torch
from PIL import Image

from nemotron_graphic_elements_v1.model import define_model


def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(
        description="Benchmark model.preprocess for various batch sizes."
    )
    parser.add_argument(
        "--image",
        default="./example.png",
        help="Path to an RGB image to use for benchmarking.",
    )
    parser.add_argument(
        "--batch-sizes",
        type=int,
        nargs="+",
        default=[1, 2, 4, 8, 16, 32],
        help="Batch sizes to benchmark.",
    )
    parser.add_argument(
        "--warmup-iters",
        type=int,
        default=3,
        help="Number of warmup iterations to run before measuring.",
    )
    parser.add_argument(
        "--measure-iters",
        type=int,
        default=10,
        help="Number of measurement iterations per batch size.",
    )
    return parser.parse_args()


def load_image(path: str) -> np.ndarray:
    """Load an RGB image as a numpy array."""
    return np.array(Image.open(path).convert("RGB"))


def benchmark_preprocess(
    model: torch.nn.Module,
    image: np.ndarray,
    batch_sizes: Iterable[int],
    warmup_iters: int,
    measure_iters: int,
) -> List[dict]:
    results: List[dict] = []
    for bs in batch_sizes:
        # Reuse the same image; preprocess does not mutate the input.
        images = [image.copy() for _ in range(bs)]

        # Warmup (also grab a sample output so the device report below works
        # even with --warmup-iters 0).
        with torch.inference_mode():
            output = model.preprocess(images[0])
            for _ in range(warmup_iters):
                for img in images:
                    output = model.preprocess(img)

        print("Preprocess output device:", output.device)
        # Measure
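        # Note: perf_counter measures host wall-clock time and nothing calls
        # torch.cuda.synchronize(), so GPU kernels launched by preprocess may
        # still be in flight when the timer stops.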
        times_s: List[float] = []
        for _ in range(measure_iters):
            start = time.perf_counter()
            with torch.inference_mode():
                batch_tensors = [model.preprocess(img) for img in images]
                _ = torch.stack(batch_tensors)
            end = time.perf_counter()
            times_s.append(end - start)

        times = np.array(times_s)
        mean_ms = float(times.mean() * 1000)
        p50_ms = float(np.percentile(times, 50) * 1000)
        p95_ms = float(np.percentile(times, 95) * 1000)
        imgs_per_sec = float(bs / times.mean())

        results.append(
            {
                "batch_size": bs,
                "mean_ms": mean_ms,
                "p50_ms": p50_ms,
                "p95_ms": p95_ms,
                "imgs_per_sec": imgs_per_sec,
            }
        )

    return results


def main() -> None:
    args = parse_args()

    image = load_image(args.image)
    model = define_model("graphic_element_v1", verbose=False)

    results = benchmark_preprocess(
        model,
        image,
        args.batch_sizes,
        args.warmup_iters,
        args.measure_iters,
    )

    print(
        f"Benchmarking preprocess on device: {model.device}, "
        f"warmup={args.warmup_iters}, iters={args.measure_iters}"
    )
    header = f"{'BS':>4} | {'mean (ms)':>10} | {'p50 (ms)':>9} | {'p95 (ms)':>9} | {'imgs/s':>8}"
    print(header)
    print("-" * len(header))
    for res in results:
        print(
            f"{res['batch_size']:>4} | "
            f"{res['mean_ms']:10.2f} | "
            f"{res['p50_ms']:9.2f} | "
            f"{res['p95_ms']:9.2f} | "
            f"{res['imgs_per_sec']:8.2f}"
        )


if __name__ == "__main__":
    main()
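
Assuming the nemotron_graphic_elements_v1 package is importable and ./example.png exists, the tables above correspond to the script's default settings:

python preprocess_benchmark.py --image ./example.png --warmup-iters 3 --measure-iters 10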



VibhuJawa changed pull request status to open