Skip to content

Commit 1e06d05

Browse files
author
SamoraHunter
committed
Fix RecursionError in logger and filter tqdm progress from logs
1 parent 4165475 commit 1e06d05

1 file changed

Lines changed: 33 additions & 4 deletions

File tree

ml_grid/util/logger_setup.py

Lines changed: 33 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,7 @@ class TeeWriter:
8888

8989
def __init__(self, original_stream, log_file_path):
9090
self.original_stream = original_stream
91+
self.log_file_path = log_file_path # Store path for pickling
9192
self.log_file = open(log_file_path, "a", buffering=1, encoding="utf-8")
9293
# Preserve attributes from original stream
9394
self.encoding = getattr(original_stream, "encoding", "utf-8")
@@ -106,10 +107,12 @@ def write(self, message: str) -> int:
106107
written = len(message)
107108

108109
# Then append to log file
109-
try:
110-
self.log_file.write(message)
111-
except Exception:
112-
pass
110+
# Filter out progress bar updates (containing carriage return) to reduce I/O overhead
111+
if "\r" not in message:
112+
try:
113+
self.log_file.write(message)
114+
except Exception:
115+
pass
113116

114117
return written
115118

@@ -143,8 +146,34 @@ def fileno(self):
143146

144147
def __getattr__(self, name):
145148
"""Delegate any other attributes to the original stream."""
149+
# This check prevents infinite recursion if original_stream is missing
150+
if name == "original_stream":
151+
raise AttributeError(
152+
f"'{type(self).__name__}' object has no attribute 'original_stream'"
153+
)
154+
155+
if not hasattr(self, "original_stream"):
156+
raise AttributeError(
157+
f"'{type(self).__name__}' object has no attribute 'original_stream'"
158+
)
159+
146160
return getattr(self.original_stream, name)
147161

162+
def __getstate__(self):
163+
state = self.__dict__.copy()
164+
if "original_stream" in state:
165+
del state["original_stream"]
166+
if "log_file" in state:
167+
del state["log_file"]
168+
return state
169+
170+
def __setstate__(self, state):
171+
self.__dict__.update(state)
172+
self.original_stream = sys.__stdout__
173+
self.log_file = open(
174+
self.log_file_path, "a", buffering=1, encoding="utf-8"
175+
)
176+
148177
sys.stdout = TeeWriter(original_stdout, stdout_log)
149178
sys.stderr = TeeWriter(original_stderr, stderr_log)
150179

0 commit comments

Comments
 (0)