Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
diff --git a/src/torchcodec/_core/Encoder.cpp b/src/torchcodec/_core/Encoder.cpp
index 2fe136fb..d024834b 100644
--- a/src/torchcodec/_core/Encoder.cpp
+++ b/src/torchcodec/_core/Encoder.cpp
@@ -846,7 +846,8 @@ void VideoEncoder::initializeEncoder(
if (videoStreamOptions.pixelFormat.has_value()) {
// TODO-VideoEncoder: (P2) Enable pixel formats to be set by user on GPU
// and handled with the appropriate NPP function on GPU.
- if (frames_.device().type() == kStableCUDA) {
+ if (frames_.device().type() == kStableCUDA ||
+ frames_.device().type() == kStableXPU) {
STD_TORCH_CHECK(
false,
"Video encoding on GPU currently only supports the nv12 pixel format. "
@@ -855,7 +856,8 @@ void VideoEncoder::initializeEncoder(
outPixelFormat =
validatePixelFormat(*avCodec, videoStreamOptions.pixelFormat.value());
} else {
- if (frames_.device().type() == kStableCUDA) {
+ if (frames_.device().type() == kStableCUDA ||
+ frames_.device().type() == kStableXPU) {
// Default to nv12 pixel format when encoding on GPU.
outPixelFormat = DeviceInterface::CUDA_ENCODING_PIXEL_FORMAT;
} else {
@@ -910,7 +912,8 @@ void VideoEncoder::initializeEncoder(
0);
}

- if (frames_.device().type() == kStableCUDA) {
+ if (frames_.device().type() == kStableCUDA ||
+ frames_.device().type() == kStableXPU) {
deviceInterface_->registerHardwareDeviceWithCodec(avCodecContext_.get());
deviceInterface_->setupHardwareFrameContextForEncoding(
avCodecContext_.get());
@@ -1208,7 +1211,7 @@ void MultiStreamEncoder::initializeVideoStream() {
if (videoStream.options.pixelFormat.has_value()) {
// TODO-MultiStreamEncoder: (P2) Enable pixel formats to be set by user on
// GPU and handled with the appropriate NPP function on GPU.
- if (deviceType == kStableCUDA) {
+ if (deviceType == kStableCUDA || deviceType == kStableXPU) {
STD_TORCH_CHECK(
false,
"Video encoding on GPU currently only supports the nv12 pixel format. "
@@ -1217,7 +1220,7 @@ void MultiStreamEncoder::initializeVideoStream() {
outPixelFormat =
validatePixelFormat(*avCodec, videoStream.options.pixelFormat.value());
} else {
- if (deviceType == kStableCUDA) {
+ if (deviceType == kStableCUDA || deviceType == kStableXPU) {
// Default to nv12 pixel format when encoding on GPU.
outPixelFormat = DeviceInterface::CUDA_ENCODING_PIXEL_FORMAT;
} else {
@@ -1275,7 +1278,7 @@ void MultiStreamEncoder::initializeVideoStream() {
0);
}

- if (deviceType == kStableCUDA) {
+ if (deviceType == kStableCUDA || deviceType == kStableXPU) {
videoStream.deviceInterface->registerHardwareDeviceWithCodec(
videoStream.avCodecContext.get());
videoStream.deviceInterface->setupHardwareFrameContextForEncoding(
143 changes: 143 additions & 0 deletions packages/torchcodec-xpu/src/torchcodec_xpu/ColorConversionKernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@ namespace facebook::torchcodec {

using float3x3 = std::array<sycl::float3, 3>;

// ============================================================
// Decoding matrices: YCbCr -> RGB (used by NV12toRGBKernel)
// ============================================================
struct rgb_matrix {
static constexpr float3x3 BT709 = {
sycl::float3{ 1.0, 0.0, 1.5748 },
Expand All @@ -23,6 +26,26 @@ struct rgb_matrix {
};
};

// ============================================================
// Encoding matrices: RGB -> YCbCr (used by RGB24toNV12Kernel)
// Inverse of rgb_matrix above.
// Row 0: Y coefficients
// Row 1: Cb coefficients
// Row 2: Cr coefficients
// ============================================================
struct yuv_matrix {
static constexpr float3x3 BT709 = {
sycl::float3{ 0.212600f, 0.715200f, 0.072200f }, // Y
sycl::float3{ -0.114572f, -0.385428f, 0.5f }, // Cb
sycl::float3{ 0.5f, -0.454153f, -0.045847f } // Cr
};
static constexpr float3x3 BT601 = {
sycl::float3{ 0.299f, 0.587f, 0.114f }, // Y
sycl::float3{ -0.168736f, -0.331264f, 0.5f }, // Cb
sycl::float3{ 0.5f, -0.418688f, -0.081312f} // Cr
};
};

// Helper function for the Intel Tile-Y offset calculation
// Intel Y-Tiling uses COLUMN-MAJOR OWord (16 bytes) organization
// Tile: 128 bytes wide × 32 rows = 4KB
Expand Down Expand Up @@ -166,6 +189,13 @@ const float3x3 getColorConversionMatrix(enum AVColorSpace colorspace) {
return rgb_matrix::BT601;
}

const float3x3 getYUVConversionMatrix(enum AVColorSpace colorspace) {
if (colorspace == AVCOL_SPC_BT709) {
return yuv_matrix::BT709;
}
return yuv_matrix::BT601;
}

void convertNV12ToRGB(
sycl::queue& queue,
const uint8_t* y_plane,
Expand Down Expand Up @@ -201,5 +231,118 @@ void registerColorConversionKernel() {
(void)s;
}

// ============================================================
// Encoding kernel: NCHW RGB tensor -> NV12 VAAPI surface
// ============================================================
struct RGB24toNV12Kernel {
const uint8_t* rgb_nchw; // CHW uint8 device pointer (R, G, B planes)
int64_t ch_stride; // stride between channel planes
int64_t row_stride; // stride between rows within a plane
int64_t pixel_stride; // stride between adjacent pixels (1 for NCHW, 3 for HWC-permuted)
uint8_t* y_plane;
uint8_t* uv_plane;
int width;
int height;
int y_pitch; // surface Y-plane row pitch in bytes
int uv_pitch; // surface UV-plane row pitch in bytes
bool is_tiled; // true → Tile-Y; false → linear
bool fullrange;
float3x3 yuv_mat;

RGB24toNV12Kernel(
const uint8_t* rgb_nchw_,
int64_t ch_stride_,
int64_t row_stride_,
int64_t pixel_stride_,
uint8_t* y_plane_,
uint8_t* uv_plane_,
int width_,
int height_,
int y_pitch_,
int uv_pitch_,
bool is_tiled_,
bool fullrange_,
const float3x3& yuv_mat_)
: rgb_nchw(rgb_nchw_),
ch_stride(ch_stride_),
row_stride(row_stride_),
pixel_stride(pixel_stride_),
y_plane(y_plane_),
uv_plane(uv_plane_),
width(width_),
height(height_),
y_pitch(y_pitch_),
uv_pitch(uv_pitch_),
is_tiled(is_tiled_),
fullrange(fullrange_),
yuv_mat(yuv_mat_)
{}

void operator()(sycl::id<2> idx) const {
int x = idx[1];
int y = idx[0];

if (x >= width || y >= height) {
return;
}

// Read RGB from NCHW tensor.
float r = rgb_nchw[0 * ch_stride + y * row_stride + x * pixel_stride] / 255.0f;
float g = rgb_nchw[1 * ch_stride + y * row_stride + x * pixel_stride] / 255.0f;
float b = rgb_nchw[2 * ch_stride + y * row_stride + x * pixel_stride] / 255.0f;
sycl::float3 src{r, g, b};

// Luma Y — write to Tile-Y or linear destination
float Y_norm = sycl::dot(src, yuv_mat[0]);
float Y = fullrange ? Y_norm * 255.0f : 16.0f + Y_norm * 219.0f;
size_t y_dst = is_tiled ? get_tile_offset(x, y, y_pitch)
: (size_t)y * y_pitch + x;
y_plane[y_dst] = (uint8_t)std::clamp(Y, 0.0f, 255.0f);

// Chroma UV: one pair per 2x2 block (NV12 4:2:0 subsampling).
if ((x % 2 == 0) && (y % 2 == 0)) {
float Cb_norm = sycl::dot(src, yuv_mat[1]);
float Cr_norm = sycl::dot(src, yuv_mat[2]);
float U = fullrange ? Cb_norm * 255.0f + 128.0f : 128.0f + Cb_norm * 224.0f;
float V = fullrange ? Cr_norm * 255.0f + 128.0f : 128.0f + Cr_norm * 224.0f;
size_t u_dst = is_tiled ? get_tile_offset(x, y / 2, uv_pitch)
: (size_t)(y / 2) * uv_pitch + x;
size_t v_dst = is_tiled ? get_tile_offset(x + 1, y / 2, uv_pitch)
: (size_t)(y / 2) * uv_pitch + x + 1;
uv_plane[u_dst] = (uint8_t)std::clamp(U, 0.0f, 255.0f);
uv_plane[v_dst] = (uint8_t)std::clamp(V, 0.0f, 255.0f);
}
}
};

void convertRGBToNV12(
sycl::queue& queue,
const uint8_t* rgb_nchw,
int64_t ch_stride,
int64_t row_stride,
int64_t pixel_stride,
uint8_t* dst_y,
uint8_t* dst_uv,
int width,
int height,
int y_pitch,
int uv_pitch,
bool is_tiled,
enum AVColorRange color_range,
enum AVColorSpace colorspace) {
bool fullrange = (color_range == AVCOL_RANGE_JPEG);
queue.submit([&](sycl::handler& cgh) {
RGB24toNV12Kernel kernel(
rgb_nchw, ch_stride, row_stride, pixel_stride,
dst_y, dst_uv,
width, height,
y_pitch, uv_pitch,
is_tiled,
fullrange, getYUVConversionMatrix(colorspace));
cgh.parallel_for(sycl::range<2>(height, width), kernel);
});
queue.wait();
}

} // namespace facebook::torchcodec
#endif // WITH_SYCL_KERNELS
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,24 @@ void convertNV12ToRGB(
enum AVColorRange color_range,
enum AVColorSpace colorspace);

// Encoding: NCHW uint8 RGB tensor (on XPU) -> NV12 VAAPI surface.
// is_tiled: true for Intel Tile-Y surfaces (drm_format_modifier != 0), false for linear.
void convertRGBToNV12(
sycl::queue& queue,
const uint8_t* rgb_nchw,
int64_t ch_stride,
int64_t row_stride,
int64_t pixel_stride,
uint8_t* dst_y,
uint8_t* dst_uv,
int width,
int height,
int y_pitch,
int uv_pitch,
bool is_tiled,
enum AVColorRange color_range,
enum AVColorSpace colorspace);

// Anchor function to force kernel registration
void registerColorConversionKernel();

Expand Down
Loading
Loading