Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 63 additions & 1 deletion acquire/outputs/tar.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
from __future__ import annotations

import copy
import io
import shutil
import tarfile
from logging import getLogger
from typing import TYPE_CHECKING, BinaryIO

from acquire.crypt import EncryptedStream
Expand All @@ -13,6 +16,7 @@
from dissect.target.filesystem import FilesystemEntry

# Maps a compression method name to a short code.
# NOTE(review): the values look like tarfile mode suffixes (as in "w:gz") - confirm against where this is used.
TAR_COMPRESSION_METHODS = {"gzip": "gz", "bzip2": "bz2", "xz": "xz"}
# Module-level logger, named after this module via __name__.
log = getLogger(__name__)


class TarOutput(Output):
Expand Down Expand Up @@ -100,7 +104,65 @@ def write(
if stat:
info.mtime = stat.st_mtime

self.tar.addfile(info, fh)
# Inline version of Python stdlib's tarfile.addfile & tarfile.copyfileobj,
# to allow for padding and more control over the tar file writing.
self.tar._check("awx")

if fh is None and info.isreg() and info.size != 0:
return

tarinfo = copy.copy(info)

saved_offset = self.tar.offset
saved_filepos = self.tar.fileobj.tell()

try:
buf = tarinfo.tobuf(self.tar.format, self.tar.encoding, self.tar.errors)
self.tar.fileobj.write(buf)
self.tar.offset += len(buf)

if fh is not None:
# Start of tarfile.copyfileobj
bufsize = self.tar.copybufsize or 16 * 1024
if tarinfo.size == 0:
return
if tarinfo.size is None:
shutil.copyfileobj(fh, self.tar.fileobj, bufsize)
return

blocks, remainder = divmod(tarinfo.size, bufsize)
for _ in range(blocks):
buf = fh.read(bufsize)
if len(buf) < bufsize:
# PATCH; instead of raising an exception, pad the data to the desired length
buf += tarfile.NUL * (bufsize - len(buf))
self.tar.fileobj.write(buf)

if remainder > 0:
buf = fh.read(remainder)
if len(buf) < remainder:
# PATCH; instead of raising an exception, pad the data to the desired length
buf += tarfile.NUL * (remainder - len(buf))
self.tar.fileobj.write(buf)
# End of tarfile.copyfileobj

blocks, remainder = divmod(tarinfo.size, tarfile.BLOCKSIZE)
if remainder > 0:
self.tar.fileobj.write(tarfile.NUL * (tarfile.BLOCKSIZE - remainder))
blocks += 1
self.tar.offset += blocks * tarfile.BLOCKSIZE

self.tar.members.append(tarinfo)
except Exception:
log.warning(
"An error occurred while writing to the tar file. "
"Truncating to the last known good state (offset: %d).",
saved_filepos,
)
self.tar.fileobj.seek(saved_filepos)
Comment thread
joost-j marked this conversation as resolved.
self.tar.fileobj.truncate()
self.tar.offset = saved_offset
raise

def close(self) -> None:
"""Closes the tar file."""
Expand Down
121 changes: 121 additions & 0 deletions tests/test_outputs_tar.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from __future__ import annotations

import io
import logging
import tarfile
from pathlib import Path
from typing import TYPE_CHECKING
Expand Down Expand Up @@ -63,3 +65,122 @@ def test_tar_output_encrypt(mock_fs: VirtualFilesystem, public_key: bytes, tmp_p

with tarfile.open(name=decrypted_path, mode="r") as tar_file:
assert entry.open().read() == tar_file.extractfile(entry_name).read()


def test_tar_output_race_condition_with_shrinking_file(tmp_path: Path, public_key: bytes) -> None:
    """Test that the tar output correctly handles a file that shrinks while being read.

    A stub file object hands back fewer bytes than requested on every read. The
    archived member is expected to keep its original length, with the missing
    tail replaced by NUL bytes.

    Args:
        tmp_path: Temporary directory (pytest fixture) holding the encrypted and decrypted archives.
        public_key: Public key (fixture) passed to ``TarOutput`` for encryption.
    """

    class ShrinkingFile(io.BytesIO):
        """In-memory file whose ``read`` returns five bytes fewer than requested."""

        def read(self, size: int) -> bytes:
            return super().read(size - 5)

    content = b"some text"

    # The last five bytes are never delivered by ShrinkingFile, so NULs are expected in their place.
    content_padded = content[:-5] + tarfile.NUL * 5
    file = ShrinkingFile(content)

    tar_output = TarOutput(tmp_path / "race.tar", encrypt=True, public_key=public_key)
    tar_output.write("file.log", file)
    tar_output.close()
    file.close()

    decrypted_path = tmp_path / "decrypted.tar"

    # Direct streaming is not an option because tarfile needs seek when reading from encrypted files directly.
    # The encrypted container is opened in a ``with`` block so its handle is closed once decryption is done.
    with tar_output.path.open("rb") as encrypted_fh:
        encrypted_stream = EncryptedFile(encrypted_fh, Path("tests/_data/private_key.pem"))
        decrypted_path.write_bytes(encrypted_stream.read())

    with tarfile.open(name=decrypted_path, mode="r") as tar_file:
        member = tar_file.getmember("file.log")
        extracted = tar_file.extractfile(member).read()
        # The content should be padded with zeros to match the original size, despite the fact that the file shrunk
        assert extracted == content_padded


def test_tar_output_race_condition_with_growing_file(tmp_path: Path, public_key: bytes) -> None:
    """Test that the tar output correctly handles a file that grows while being read.

    A stub file object appends extra bytes to every read. The archived member is
    expected to contain only the original content.

    Args:
        tmp_path: Temporary directory (pytest fixture) holding the encrypted and decrypted archives.
        public_key: Public key (fixture) passed to ``TarOutput`` for encryption.
    """

    class GrowingFile(io.BytesIO):
        """In-memory file whose ``read`` appends ``b"FOX"`` to whatever the buffer returns."""

        # NOTE(review): this returns up to three bytes more than ``size``, which regular file
        # objects do not do. The assertion below only inspects the member payload, so it does
        # not tell whether those extra bytes end up in the archive stream - confirm intent.
        def read(self, size: int) -> bytes:
            return super().read(size) + b"FOX"

    content = b"some text"

    file = GrowingFile(content)

    tar_output = TarOutput(tmp_path / "race.tar", encrypt=True, public_key=public_key)
    tar_output.write("file.log", file)
    tar_output.close()
    file.close()

    decrypted_path = tmp_path / "decrypted.tar"

    # Direct streaming is not an option because tarfile needs seek when reading from encrypted files directly.
    # The encrypted container is opened in a ``with`` block so its handle is closed once decryption is done.
    with tar_output.path.open("rb") as encrypted_fh:
        encrypted_stream = EncryptedFile(encrypted_fh, Path("tests/_data/private_key.pem"))
        decrypted_path.write_bytes(encrypted_stream.read())

    with tarfile.open(name=decrypted_path, mode="r") as tar_file:
        member = tar_file.getmember("file.log")
        extracted = tar_file.extractfile(member).read()
        # The content should match the original content, without the extra bytes
        # because the file was read with the original size
        assert extracted == content


def test_tar_output_exception_rollback(tmp_path: Path, caplog: pytest.LogCaptureFixture) -> None:
    """Check that a failed write leaves the archive exactly as it was before that write.

    One entry is written successfully, then a second write is fed by a source that
    raises part-way through. The error must propagate, a single warning must be
    logged, and the archive must still open with only the first entry in it.

    Args:
        tmp_path: Temporary directory (pytest fixture) in which the archive is created.
        caplog: Log capture fixture used to inspect the emitted warning.
    """

    class FailingFile(io.BytesIO):
        """In-memory file that raises once more than ``fail_after_bytes`` bytes were read."""

        def __init__(self, data: bytes, fail_after_bytes: int = 5):
            super().__init__(data)
            self.fail_after_bytes = fail_after_bytes
            self.bytes_read = 0

        def read(self, size: int) -> bytes:
            chunk = super().read(size)
            self.bytes_read += len(chunk)
            if self.bytes_read <= self.fail_after_bytes:
                return chunk
            raise OSError("Simulated I/O error during file read")

    payload = b"This is some test content that will fail during reading"
    broken_source = FailingFile(payload, fail_after_bytes=5)

    tar_output = TarOutput(tmp_path / "test.tar")

    # Known-good baseline: one entry that is written without problems.
    good_source = io.BytesIO(b"dissectftw")
    tar_output.write("successful_file.txt", good_source)

    offset_before = tar_output.tar.fileobj.tell()
    member_count_before = len(tar_output.tar.members)

    with caplog.at_level(logging.WARNING, logger="acquire.outputs.tar"):
        with pytest.raises(OSError, match="Simulated I/O error during file read"):
            # Attempt to write the failing file
            tar_output.write("failing_file.txt", broken_source, size=len(payload))

    # Exactly one warning, describing both the error and the rollback target.
    # NOTE(review): 1024 is presumably one 512-byte header block plus one 512-byte data block
    # for the first entry - confirm if the baseline entry ever changes.
    logged_messages = [record.message for record in caplog.records]
    assert logged_messages == [
        "An error occurred while writing to the tar file. Truncating to the last known good state (offset: 1024)."
    ]

    # The underlying file position and the member list are back to their pre-failure state.
    assert tar_output.tar.fileobj.tell() == offset_before
    assert len(tar_output.tar.members) == member_count_before

    tar_output.close()

    # The archive must remain readable and contain only the successful entry.
    with tarfile.open(tar_output.path) as tar_file:
        member_names = [member.name for member in tar_file.getmembers()]
        assert member_names == ["successful_file.txt"]
        assert tar_file.extractfile("successful_file.txt").read() == b"dissectftw"
Loading