Skip to content

Commit 5b076f9

Browse files
committed
relink.py: Check that files to process are children of CESM input data root.
1 parent ecc6b9e commit 5b076f9

9 files changed

Lines changed: 162 additions & 43 deletions

relink.py

Lines changed: 50 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import argparse
1111
import logging
1212
import time
13+
from pathlib import Path
1314

1415
DEFAULT_SOURCE_ROOT = "/glade/campaign/cesm/cesmdata/cseg/inputdata/"
1516
DEFAULT_TARGET_ROOT = (
@@ -60,18 +61,23 @@ def _handle_non_dir_entry(entry, user_uid):
6061
return None
6162

6263

63-
def handle_non_dir(var, user_uid):
64+
def handle_non_dir(var, user_uid, inputdata_root):
6465
"""
6566
Check if a non-directory is owned by the user and should be processed. Passes var to a
6667
helper function depending on its type.
6768
6869
Args:
6970
var (os.DirEntry or str): A directory entry from os.scandir(), or a string path.
7071
user_uid (int): The UID of the user whose files to find.
72+
inputdata_root (str): The root of the directory tree containing CESM input data.
7173
7274
Returns:
7375
str or None: The absolute path to the file if it's owned by the user
7476
and is a regular file (not a symlink), otherwise None.
77+
78+
Raises:
79+
TypeError: If var is not a DirEntry-like object.
80+
ValueError: If the file path is not under inputdata_root.
7581
"""
7682

7783
# Fall back to duck typing: If var has the required DirEntry methods and members, treat it as a
@@ -80,12 +86,22 @@ def handle_non_dir(var, user_uid):
8086
if isinstance(var, os.DirEntry) or all(
8187
hasattr(var, m) for m in ["stat", "is_file", "is_symlink", "path"]
8288
):
83-
return _handle_non_dir_entry(var, user_uid)
89+
file_path = _handle_non_dir_entry(var, user_uid)
90+
else:
91+
raise TypeError(
92+
f"Unsure how to handle non-directory variable of type {type(var)}"
93+
)
8494

85-
raise TypeError(f"Unsure how to handle non-directory variable of type {type(var)}")
95+
# Check that resulting path is a child of inputdata_root
96+
if file_path is not None and not Path(file_path).is_relative_to(inputdata_root):
97+
raise ValueError(
98+
f"'{file_path}' must be equivalent to or under '{inputdata_root}"
99+
)
100+
101+
return file_path
86102

87103

88-
def find_owned_files_scandir(directory, user_uid):
104+
def find_owned_files_scandir(directory, user_uid, inputdata_root=DEFAULT_SOURCE_ROOT):
89105
"""
90106
Efficiently find all files owned by a specific user using os.scandir().
91107
@@ -95,20 +111,28 @@ def find_owned_files_scandir(directory, user_uid):
95111
Args:
96112
directory (str): The root directory to search.
97113
user_uid (int): The UID of the user whose files to find.
114+
inputdata_root (str): The root of the directory tree containing CESM input data.
98115
99116
Yields:
100117
str: Absolute paths to files owned by the user.
118+
119+
Raises:
120+
ValueError: If any file found is not under inputdata_root.
101121
"""
102122
try:
103123
with os.scandir(directory) as entries:
104124
for entry in entries:
105125
try:
106126
# Recursively process directories (not following symlinks)
107127
if entry.is_dir(follow_symlinks=False):
108-
yield from find_owned_files_scandir(entry.path, user_uid)
128+
yield from find_owned_files_scandir(
129+
entry.path, user_uid, inputdata_root
130+
)
109131

110132
# Things other than directories are handled separately
111-
elif (entry_path := handle_non_dir(entry, user_uid)) is not None:
133+
elif (
134+
entry_path := handle_non_dir(entry, user_uid, inputdata_root)
135+
) is not None:
112136
yield entry_path
113137

114138
except (OSError, PermissionError) as e:
@@ -119,7 +143,9 @@ def find_owned_files_scandir(directory, user_uid):
119143
logger.debug("Error accessing %s: %s. Skipping.", directory, e)
120144

121145

122-
def replace_files_with_symlinks(source_dir, target_dir, username, dry_run=False):
146+
def replace_files_with_symlinks(
147+
source_dir, target_dir, username, inputdata_root=DEFAULT_SOURCE_ROOT, dry_run=False
148+
):
123149
"""
124150
Finds files owned by a specific user in a source directory tree,
125151
deletes them, and replaces them with symbolic links to the same
@@ -128,6 +154,7 @@ def replace_files_with_symlinks(source_dir, target_dir, username, dry_run=False)
128154
Args:
129155
source_dir (str): The root of the directory tree to search for files.
130156
target_dir (str): The root of the directory tree containing the new files.
157+
inputdata_root (str): The root of the directory tree containing CESM input data.
131158
username (str): The name of the user whose files will be processed.
132159
dry_run (bool): If True, only show what would be done without making changes.
133160
"""
@@ -152,7 +179,7 @@ def replace_files_with_symlinks(source_dir, target_dir, username, dry_run=False)
152179
)
153180

154181
# Use efficient scandir-based search
155-
for file_path in find_owned_files_scandir(source_dir, user_uid):
182+
for file_path in find_owned_files_scandir(source_dir, user_uid, inputdata_root):
156183
logger.info("Found owned file: %s", file_path)
157184

158185
# Determine the relative path and the new link's destination
@@ -251,6 +278,16 @@ def parse_arguments():
251278
),
252279
)
253280

281+
# The root of the directory tree containing CESM input data.
282+
# ONLY INTENDED FOR USE IN TESTING
283+
parser.add_argument(
284+
"--inputdata-root",
285+
"-inputdata", # to match rimport
286+
type=validate_directory,
287+
default=DEFAULT_SOURCE_ROOT,
288+
help=argparse.SUPPRESS,
289+
)
290+
254291
# Verbosity options (mutually exclusive)
255292
verbosity_group = parser.add_mutually_exclusive_group()
256293
verbosity_group.add_argument(
@@ -311,7 +348,11 @@ def main():
311348

312349
# --- Execution ---
313350
replace_files_with_symlinks(
314-
args.source_root, args.target_root, my_username, dry_run=args.dry_run
351+
args.source_root,
352+
args.target_root,
353+
my_username,
354+
inputdata_root=args.inputdata_root,
355+
dry_run=args.dry_run,
315356
)
316357

317358
if args.timing:

tests/relink/conftest.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import shutil
88

99
import pytest
10+
from unittest.mock import patch
1011

1112

1213
@pytest.fixture(scope="function", name="temp_dirs")
@@ -15,7 +16,8 @@ def fixture_temp_dirs():
1516
source_dir = tempfile.mkdtemp(prefix="test_source_")
1617
target_dir = tempfile.mkdtemp(prefix="test_target_")
1718

18-
yield source_dir, target_dir
19+
with patch("relink.DEFAULT_SOURCE_ROOT", source_dir):
20+
yield source_dir, target_dir
1921

2022
# Cleanup
2123
shutil.rmtree(source_dir, ignore_errors=True)

tests/relink/test_args.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,10 +21,11 @@
2121

2222

2323
@pytest.fixture(scope="function", name="mock_default_dirs")
24-
def fixture_mock_default_dirs():
24+
def fixture_mock_default_dirs(temp_dirs):
2525
"""Mock the default directories to use temporary directories."""
2626
source_dir = tempfile.mkdtemp(prefix="test_default_source_")
2727
target_dir = tempfile.mkdtemp(prefix="test_default_target_")
28+
source_dir, target_dir = temp_dirs
2829

2930
with patch.object(relink, "DEFAULT_SOURCE_ROOT", source_dir):
3031
with patch.object(relink, "DEFAULT_TARGET_ROOT", target_dir):

tests/relink/test_cmdline.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,8 @@ def test_command_line_execution_dry_run(mock_dirs):
4444
"--target-root",
4545
str(target_dir),
4646
"--dry-run",
47+
"--inputdata-root",
48+
str(source_dir),
4749
]
4850

4951
# Execute the command
@@ -78,6 +80,8 @@ def test_command_line_execution_actual_run(mock_dirs):
7880
str(source_dir),
7981
"--target-root",
8082
str(target_dir),
83+
"-inputdata",
84+
str(source_dir),
8185
]
8286

8387
# Execute the command

tests/relink/test_dryrun.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import os
66
import sys
77
import logging
8+
from unittest.mock import patch
89

910
import pytest
1011

@@ -46,7 +47,7 @@ def test_dry_run_no_changes(dry_run_setup, caplog):
4647
# Run in dry-run mode
4748
with caplog.at_level(logging.INFO):
4849
relink.replace_files_with_symlinks(
49-
source_dir, target_dir, username, dry_run=True
50+
source_dir, target_dir, username, inputdata_root=source_dir, dry_run=True
5051
)
5152

5253
# Verify no changes were made
@@ -64,7 +65,7 @@ def test_dry_run_shows_message(dry_run_setup, caplog):
6465
# Run in dry-run mode
6566
with caplog.at_level(logging.INFO):
6667
relink.replace_files_with_symlinks(
67-
source_dir, target_dir, username, dry_run=True
68+
source_dir, target_dir, username, inputdata_root=source_dir, dry_run=True
6869
)
6970

7071
# Check that dry-run messages were logged
@@ -80,7 +81,7 @@ def test_dry_run_no_delete_or_create_messages(dry_run_setup, caplog):
8081
# Run in dry-run mode
8182
with caplog.at_level(logging.INFO):
8283
relink.replace_files_with_symlinks(
83-
source_dir, target_dir, username, dry_run=True
84+
source_dir, target_dir, username, inputdata_root=source_dir, dry_run=True
8485
)
8586

8687
# Verify actual operation messages are NOT logged

tests/relink/test_find_owned_files_scandir.py

Lines changed: 32 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,9 @@ def test_find_owned_files_basic(temp_dirs):
103103
f.write("content2")
104104

105105
# Find owned files
106-
found_files = list(relink.find_owned_files_scandir(source_dir, user_uid))
106+
found_files = list(
107+
relink.find_owned_files_scandir(source_dir, user_uid, inputdata_root=source_dir)
108+
)
107109

108110
# Verify both files were found
109111
assert len(found_files) == 2
@@ -130,7 +132,9 @@ def test_find_owned_files_nested(temp_dirs):
130132
fp.write("content")
131133

132134
# Find owned files
133-
found_files = list(relink.find_owned_files_scandir(source_dir, user_uid))
135+
found_files = list(
136+
relink.find_owned_files_scandir(source_dir, user_uid, inputdata_root=source_dir)
137+
)
134138

135139
# Verify all files were found
136140
assert len(found_files) == 3
@@ -156,7 +160,11 @@ def test_skip_symlinks(temp_dirs, caplog):
156160

157161
# Find owned files with logging
158162
with caplog.at_level(logging.DEBUG):
159-
found_files = list(relink.find_owned_files_scandir(source_dir, user_uid))
163+
found_files = list(
164+
relink.find_owned_files_scandir(
165+
source_dir, user_uid, inputdata_root=source_dir
166+
)
167+
)
160168

161169
# Verify only regular file was found
162170
assert len(found_files) == 1
@@ -197,7 +205,11 @@ def test_skip_symlinks_owned_by_different_user(temp_dirs, caplog):
197205

198206
with patch("os.scandir", side_effect=mock_scandir):
199207
with caplog.at_level(logging.INFO):
200-
found_files = list(relink.find_owned_files_scandir(source_dir, user_uid))
208+
found_files = list(
209+
relink.find_owned_files_scandir(
210+
source_dir, user_uid, inputdata_root=source_dir
211+
)
212+
)
201213

202214
# Verify only regular file was found
203215
assert len(found_files) == 1
@@ -216,7 +228,9 @@ def test_empty_directory(temp_dirs):
216228
user_uid = os.stat(source_dir).st_uid
217229

218230
# Find owned files in empty directory
219-
found_files = list(relink.find_owned_files_scandir(source_dir, user_uid))
231+
found_files = list(
232+
relink.find_owned_files_scandir(source_dir, user_uid, inputdata_root=source_dir)
233+
)
220234

221235
# Should return empty list
222236
assert len(found_files) == 0
@@ -245,7 +259,11 @@ def test_permission_error_handling(temp_dirs, caplog):
245259
try:
246260
# Find owned files with debug logging
247261
with caplog.at_level(logging.DEBUG):
248-
found_files = list(relink.find_owned_files_scandir(source_dir, user_uid))
262+
found_files = list(
263+
relink.find_owned_files_scandir(
264+
source_dir, user_uid, inputdata_root=source_dir
265+
)
266+
)
249267

250268
# Should find the accessible file but skip the inaccessible directory
251269
assert file1 in found_files
@@ -272,7 +290,9 @@ def test_only_files_not_directories(temp_dirs):
272290
os.makedirs(subdir)
273291

274292
# Find owned files
275-
found_files = list(relink.find_owned_files_scandir(source_dir, user_uid))
293+
found_files = list(
294+
relink.find_owned_files_scandir(source_dir, user_uid, inputdata_root=source_dir)
295+
)
276296

277297
# Should only find the file, not the directory
278298
assert len(found_files) == 1
@@ -303,7 +323,11 @@ def test_does_not_follow_symlink_directories(temp_dirs):
303323
os.symlink(external_dir, symlink_dir)
304324

305325
# Find owned files
306-
found_files = list(relink.find_owned_files_scandir(source_dir, user_uid))
326+
found_files = list(
327+
relink.find_owned_files_scandir(
328+
source_dir, user_uid, inputdata_root=source_dir
329+
)
330+
)
307331

308332
# Should find file in real directory but not in symlinked directory
309333
assert file_in_real in found_files

0 commit comments

Comments
 (0)