@@ -1,7 +1,6 @@
 import math
 import os
 import pickle
-from itertools import repeat
 from pathlib import Path
 from typing import BinaryIO
 
@@ -82,30 +81,56 @@ def _write_index_segment(file_descriptor: BinaryIO, index_list: list[tuple[int, |
 def _write_data_segment(
     file_descriptor: BinaryIO, token_data: list[np.ndarray], token_size_in_bytes: int, write_batch_size: int
 ) -> list[tuple[int, int]]:
-    def encoded_token_to_bytes(encoded_token: int, token_size_in_bytes: int) -> bytes:
-        # Converts an token_ids to its byte representation.
-        try:
-            token_bytes = encoded_token.to_bytes(token_size_in_bytes, byteorder="little", signed=False)
-        except OverflowError as e:
-            raise ValueError(f"Token {encoded_token} cannot be represented by {token_size_in_bytes} bytes.") from e
-        return token_bytes
-
-    samples = []
-    index_list = []
+    # Fast path: vectorized cast + tobytes (no per-token Python work).
+    # Preserves little-endian unsigned representation and overflow checks.
+
+    if token_size_in_bytes == 1:
+        dtype = np.dtype("u1")
+    elif token_size_in_bytes == 2:
+        dtype = np.dtype("<u2")  # force little-endian
+    elif token_size_in_bytes == 4:
+        dtype = np.dtype("<u4")  # force little-endian
+    else:
+        raise ValueError("Only token byte sizes of 1, 2, and 4 are currently supported.")
+
+    max_allowed = 2 ** (8 * token_size_in_bytes) - 1
+
+    samples: list[bytes] = []
+    index_list: list[tuple[int, int]] = []
     curr_offset = 0
+    pending = 0
+
     for sample_tokens in token_data:
-        # convert token_ids to byte representation
-        sample_token_byte_string = b"".join(
-            map(encoded_token_to_bytes, sample_tokens.tolist(), repeat(token_size_in_bytes))
-        )
+        arr = np.asarray(sample_tokens)
+
+        # ---- Overflow / range check (preserves original semantics) ----
+        if arr.size:
+            min_val = int(arr.min())
+            max_val = int(arr.max())
+            if min_val < 0 or max_val > max_allowed:
+                raise ValueError(
+                    f"Token values out of range for {token_size_in_bytes} bytes: "
+                    f"min={min_val}, max={max_val}, allowed=[0, {max_allowed}]"
+                )
+        # ----------------------------------------------------------------
+
+        # Cast to the correct unsigned little-endian dtype before serializing.
+        arr = np.asarray(arr, dtype=dtype, order="C")
+        sample_token_byte_string = arr.tobytes(order="C")
+
         samples.append(sample_token_byte_string)
         index_list.append((curr_offset, len(sample_token_byte_string)))
         curr_offset += len(sample_token_byte_string)
-        if len(samples) % write_batch_size == 0:
+
+        pending += 1
+        if pending >= write_batch_size:
             file_descriptor.write(b"".join(samples))
-            samples = []
+            samples.clear()
+            pending = 0
+
     if len(samples) > 0:
         file_descriptor.write(b"".join(samples))
+
     return index_list
 
 @staticmethod
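For context, here is a minimal, self-contained sketch of why the vectorized fast path produces byte-for-byte the same output as the removed per-token helper. The `tokens` array and `token_size_in_bytes` value below are illustrative, not taken from the PR:

```python
# Sketch: the vectorized cast + tobytes path is byte-identical to the old
# per-token int.to_bytes(..., byteorder="little", signed=False) encoding.
import numpy as np

token_size_in_bytes = 2
tokens = np.array([0, 1, 255, 256, 65535])  # all in range for 2 bytes

# Old path: one Python-level int.to_bytes call per token.
slow = b"".join(
    int(t).to_bytes(token_size_in_bytes, byteorder="little", signed=False)
    for t in tokens.tolist()
)

# New path: one vectorized cast, one tobytes call.
# "<u2" pins little-endian byte order regardless of the host platform;
# a 1-byte dtype ("u1") has no byte order, so it needs no "<" prefix.
arr = np.asarray(tokens, dtype=np.dtype("<u2"), order="C")
fast = arr.tobytes(order="C")

assert slow == fast  # identical serialization

# The up-front range check mirrors the old OverflowError behavior:
max_allowed = 2 ** (8 * token_size_in_bytes) - 1
assert 0 <= int(tokens.min()) and int(tokens.max()) <= max_allowed
```

One design note: a NumPy integer cast wraps out-of-range values silently rather than raising, which is why the diff adds the explicit min/max check before the cast instead of relying on an exception the way the old `int.to_bytes` path did.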