diff --git a/acquire/outputs/tar.py b/acquire/outputs/tar.py index 69b72906..a463d80b 100644 --- a/acquire/outputs/tar.py +++ b/acquire/outputs/tar.py @@ -1,7 +1,10 @@ from __future__ import annotations +import copy import io +import shutil import tarfile +from logging import getLogger from typing import TYPE_CHECKING, BinaryIO from acquire.crypt import EncryptedStream @@ -13,6 +16,7 @@ from dissect.target.filesystem import FilesystemEntry TAR_COMPRESSION_METHODS = {"gzip": "gz", "bzip2": "bz2", "xz": "xz"} +log = getLogger(__name__) class TarOutput(Output): @@ -100,7 +104,65 @@ def write( if stat: info.mtime = stat.st_mtime - self.tar.addfile(info, fh) + # Inline version of Python stdlib's tarfile.addfile & tarfile.copyfileobj, + # to allow for padding and more control over the tar file writing. + self.tar._check("awx") + + if fh is None and info.isreg() and info.size != 0: + return + + tarinfo = copy.copy(info) + + saved_offset = self.tar.offset + saved_filepos = self.tar.fileobj.tell() + + try: + buf = tarinfo.tobuf(self.tar.format, self.tar.encoding, self.tar.errors) + self.tar.fileobj.write(buf) + self.tar.offset += len(buf) + + if fh is not None: + # Start of tarfile.copyfileobj + bufsize = self.tar.copybufsize or 16 * 1024 + if tarinfo.size == 0: + return + if tarinfo.size is None: + shutil.copyfileobj(fh, self.tar.fileobj, bufsize) + return + + blocks, remainder = divmod(tarinfo.size, bufsize) + for _ in range(blocks): + buf = fh.read(bufsize) + if len(buf) < bufsize: + # PATCH; instead of raising an exception, pad the data to the desired length + buf += tarfile.NUL * (bufsize - len(buf)) + self.tar.fileobj.write(buf) + + if remainder > 0: + buf = fh.read(remainder) + if len(buf) < remainder: + # PATCH; instead of raising an exception, pad the data to the desired length + buf += tarfile.NUL * (remainder - len(buf)) + self.tar.fileobj.write(buf) + # End of tarfile.copyfileobj + + blocks, remainder = divmod(tarinfo.size, tarfile.BLOCKSIZE) + if remainder > 0: + self.tar.fileobj.write(tarfile.NUL * (tarfile.BLOCKSIZE - remainder)) + blocks += 1 + self.tar.offset += blocks * tarfile.BLOCKSIZE + + self.tar.members.append(tarinfo) + except Exception: + log.warning( + "An error occurred while writing to the tar file. " + "Truncating to the last known good state (offset: %d).", + saved_filepos, + ) + self.tar.fileobj.seek(saved_filepos) + self.tar.fileobj.truncate() + self.tar.offset = saved_offset + raise def close(self) -> None: """Closes the tar file.""" diff --git a/tests/test_outputs_tar.py b/tests/test_outputs_tar.py index 81059bb2..1b512919 100644 --- a/tests/test_outputs_tar.py +++ b/tests/test_outputs_tar.py @@ -1,5 +1,7 @@ from __future__ import annotations +import io +import logging import tarfile from pathlib import Path from typing import TYPE_CHECKING @@ -63,3 +65,122 @@ def test_tar_output_encrypt(mock_fs: VirtualFilesystem, public_key: bytes, tmp_p with tarfile.open(name=decrypted_path, mode="r") as tar_file: assert entry.open().read() == tar_file.extractfile(entry_name).read() + + +def test_tar_output_race_condition_with_shrinking_file(tmp_path: Path, public_key: bytes) -> None: + """Test that the tar output correctly handles a file that shrinks while being read.""" + + class ShrinkingFile(io.BytesIO): + def __init__(self, data: bytes): + super().__init__(data) + + def read(self, size: int) -> bytes: + return super().read(size - 5) + + content = b"some text" + + content_padded = content[:-5] + tarfile.NUL * 5 + file = ShrinkingFile(content) + + tar_output = TarOutput(tmp_path / "race.tar", encrypt=True, public_key=public_key) + tar_output.write("file.log", file) + tar_output.close() + file.close() + + encrypted_stream = EncryptedFile(tar_output.path.open("rb"), Path("tests/_data/private_key.pem")) + decrypted_path = tmp_path / "decrypted.tar" + + # Direct streaming is not an option because tarfile needs seek when reading from encrypted files directly + Path(decrypted_path).write_bytes(encrypted_stream.read()) + + with tarfile.open(name=decrypted_path, mode="r") as tar_file: + member = tar_file.getmember("file.log") + extracted = tar_file.extractfile(member).read() + # The content should be padded with zeros to match the original size, despite the fact that the file shrunk + assert extracted == content_padded + + +def test_tar_output_race_condition_with_growing_file(tmp_path: Path, public_key: bytes) -> None: + """Test that the tar output correctly handles a file that grows while being read.""" + + class GrowingFile(io.BytesIO): + def __init__(self, data: bytes): + super().__init__(data) + + def read(self, size: int) -> bytes: + return super().read(size) + b"FOX" + + content = b"some text" + + file = GrowingFile(content) + + tar_output = TarOutput(tmp_path / "race.tar", encrypt=True, public_key=public_key) + tar_output.write("file.log", file) + tar_output.close() + file.close() + + encrypted_stream = EncryptedFile(tar_output.path.open("rb"), Path("tests/_data/private_key.pem")) + decrypted_path = tmp_path / "decrypted.tar" + + # Direct streaming is not an option because tarfile needs seek when reading from encrypted files directly + Path(decrypted_path).write_bytes(encrypted_stream.read()) + + with tarfile.open(name=decrypted_path, mode="r") as tar_file: + member = tar_file.getmember("file.log") + extracted = tar_file.extractfile(member).read() + # The content should match the original content, without the extra bytes + # because the file was read with the original size + assert extracted == content + + +def test_tar_output_exception_rollback(tmp_path: Path, caplog: pytest.LogCaptureFixture) -> None: + """Test that tar file is properly truncated when an exception occurs during writing.""" + + class FailingFile(io.BytesIO): + def __init__(self, data: bytes, fail_after_bytes: int = 5): + super().__init__(data) + self.fail_after_bytes = fail_after_bytes + self.bytes_read = 0 + + def read(self, size: int) -> bytes: + data = super().read(size) + self.bytes_read += len(data) + if self.bytes_read > self.fail_after_bytes: + raise IOError("Simulated I/O error during file read") + return data + + content = b"This is some test content that will fail during reading" + failing_file = FailingFile(content, fail_after_bytes=5) + + tar_output = TarOutput(tmp_path / "test.tar") + + successful_file = io.BytesIO(b"dissectftw") + tar_output.write("successful_file.txt", successful_file) + + file_size_before_failure = tar_output.tar.fileobj.tell() + members_count_before = len(tar_output.tar.members) + with ( + caplog.at_level(logging.WARNING, logger="acquire.outputs.tar"), + pytest.raises(IOError, match="Simulated I/O error during file read"), + ): + # Attempt to write the failing file + tar_output.write("failing_file.txt", failing_file, size=len(content)) + + # Check that a warning was logged about the error and the rollback + assert len(caplog.records) == 1 + assert caplog.records[0].message == ( + "An error occurred while writing to the tar file. Truncating to the last known good state (offset: 1024)." + ) + + # Verify that the tar file was truncated back to its state before the failed write + assert tar_output.tar.fileobj.tell() == file_size_before_failure + assert len(tar_output.tar.members) == members_count_before + + tar_output.close() + + # Verify the tar file can still be opened and contains only the successful entry + with tarfile.open(tar_output.path) as tar_file: + members = tar_file.getmembers() + assert len(members) == 1 + assert members[0].name == "successful_file.txt" + assert tar_file.extractfile("successful_file.txt").read() == b"dissectftw"