From c50b24a5b3317a96efe5324d8e078a12d9876e1c Mon Sep 17 00:00:00 2001 From: keunwoochoi Date: Mon, 21 Jul 2025 22:10:55 -0400 Subject: [PATCH 1/9] add text streaming decoder --- requirements.txt | 1 + test/nodes/io/test_text_streaming_decoder.py | 405 +++++++++++++++++++ test/requirements.txt | 1 + torchdata/nodes/io/__init__.py | 3 + torchdata/nodes/io/text_streaming_decoder.py | 233 +++++++++++ 5 files changed, 643 insertions(+) create mode 100644 test/nodes/io/test_text_streaming_decoder.py create mode 100644 torchdata/nodes/io/__init__.py create mode 100644 torchdata/nodes/io/text_streaming_decoder.py diff --git a/requirements.txt b/requirements.txt index 14a4b8fa8..7f67a20b1 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,2 +1,3 @@ urllib3 >= 1.25 requests +smart_open diff --git a/test/nodes/io/test_text_streaming_decoder.py b/test/nodes/io/test_text_streaming_decoder.py new file mode 100644 index 000000000..db8079522 --- /dev/null +++ b/test/nodes/io/test_text_streaming_decoder.py @@ -0,0 +1,405 @@ +"""Tests for TextStreamingDecoder with various file sources and formats. + +This test suite verifies the functionality of TextStreamingDecoder across +different file sources (local, S3) and formats (plain text, compressed). + +Test coverage includes: +1. Local file operations + - Basic reading + - Metadata handling + - State management and resumption + - Empty file handling + - Text encoding (UTF-8) + - File handle cleanup + +2. S3 operations (mocked) + - Basic reading from S3 + - Using transport parameters + - State management with S3 files + +3. Compressed file handling + - Reading .gz files + - Reading .bz2 files + +4. Mixed source operations + - Reading from multiple files + - Reading from both compressed and uncompressed sources + +5. 
Error handling + - Invalid file paths + - Recovery from errors +""" + +import os +import tempfile +from typing import Any, Dict, List, Union +from unittest.mock import MagicMock, patch + +import pytest + +from torchdata.nodes import BaseNode +from torchdata.nodes.io.text_streaming_decoder import TextStreamingDecoder + + +class MockSourceNode(BaseNode[Dict]): + """Mock source node that provides file paths for testing.""" + + def __init__(self, file_paths: List[str], metadata: Dict[str, Any] = None): + super().__init__() + self.file_paths = file_paths + self.metadata = metadata or {} + self._current_idx = 0 + + def reset(self, initial_state=None): + super().reset(initial_state) + if initial_state is not None: + self._current_idx = initial_state.get("idx", 0) + else: + self._current_idx = 0 + + def next(self) -> Dict: + if self._current_idx >= len(self.file_paths): + raise StopIteration("No more files") + + path = self.file_paths[self._current_idx] + self._current_idx += 1 + + return {TextStreamingDecoder.DATA_KEY: path, TextStreamingDecoder.METADATA_KEY: dict(self.metadata)} + + def get_state(self): + return {"idx": self._current_idx} + + +def create_test_files(): + """Create temporary test files with known content.""" + temp_dir = tempfile.mkdtemp() + + # Create first test file + file1_path = os.path.join(temp_dir, "test1.txt") + with open(file1_path, "w") as f: + f.write("line1\nline2\nline3\n") + + # Create second test file + file2_path = os.path.join(temp_dir, "test2.txt") + with open(file2_path, "w") as f: + f.write("file2_line1\nfile2_line2\n") + + return temp_dir, [file1_path, file2_path] + + +def test_text_stream_basic(): + """Test basic functionality of TextStreamingDecoder.""" + temp_dir, file_paths = create_test_files() + try: + source_node = MockSourceNode(file_paths) + node = TextStreamingDecoder(source_node) + + # Test reading all lines + lines = [] + for item in node: + lines.append(item[TextStreamingDecoder.DATA_KEY]) + + # Check content + assert lines == ["line1", "line2", "line3", "file2_line1", "file2_line2"] + + finally: + for path in file_paths: + if os.path.exists(path): + os.remove(path) + os.rmdir(temp_dir) + + +def test_text_stream_metadata(): + """Test metadata handling in TextStreamingDecoder.""" + temp_dir, file_paths = create_test_files() + try: + source_node = MockSourceNode(file_paths, {"source": "local"}) + node = TextStreamingDecoder(source_node) + + # Get first item + item = next(iter(node)) + + # Check metadata + assert TextStreamingDecoder.METADATA_KEY in item + assert "file_path" in item[TextStreamingDecoder.METADATA_KEY] + assert item[TextStreamingDecoder.METADATA_KEY]["file_path"] == file_paths[0] + assert item[TextStreamingDecoder.METADATA_KEY]["item_idx"] == 0 + assert item[TextStreamingDecoder.METADATA_KEY]["source"] == "local" + + finally: + for path in file_paths: + if os.path.exists(path): + os.remove(path) + os.rmdir(temp_dir) + + +def test_text_stream_state_management(): + """Test state management in TextStreamingDecoder.""" + temp_dir, file_paths = create_test_files() + try: + source_node = MockSourceNode(file_paths) + node = TextStreamingDecoder(source_node) + + # Read first line and store state + first_item = next(iter(node)) + state = node.get_state() + + # Create new node and restore state + new_source = MockSourceNode(file_paths) + new_node = TextStreamingDecoder(new_source) + new_node.reset(state) + + # Read next line - should be second line + second_item = next(iter(new_node)) + + # Verify it's different from the first line + assert 
second_item[TextStreamingDecoder.DATA_KEY] != first_item[TextStreamingDecoder.DATA_KEY] + assert second_item[TextStreamingDecoder.METADATA_KEY]["item_idx"] == 1 + + finally: + for path in file_paths: + if os.path.exists(path): + os.remove(path) + os.rmdir(temp_dir) + + +def test_text_stream_empty_file(): + """Test handling of empty files.""" + temp_dir = tempfile.mkdtemp() + empty_file = os.path.join(temp_dir, "empty.txt") + normal_file = os.path.join(temp_dir, "normal.txt") + + try: + # Create empty file + with open(empty_file, "w") as f: + pass + + # Create normal file + with open(normal_file, "w") as f: + f.write("normal_content\n") + + source_node = MockSourceNode([empty_file, normal_file]) + node = TextStreamingDecoder(source_node) + + # Should skip empty file and read from normal file + item = next(iter(node)) + assert item[TextStreamingDecoder.DATA_KEY] == "normal_content" + + finally: + for path in [empty_file, normal_file]: + if os.path.exists(path): + os.remove(path) + os.rmdir(temp_dir) + + +def test_text_stream_encoding(): + """Test text encoding handling.""" + temp_dir = tempfile.mkdtemp() + utf8_file = os.path.join(temp_dir, "utf8.txt") + + try: + # Create file with UTF-8 content + content = "Hello 世界\n" + with open(utf8_file, "w", encoding="utf-8") as f: + f.write(content) + + source_node = MockSourceNode([utf8_file]) + node = TextStreamingDecoder(source_node, encoding="utf-8") + + # Read content + item = next(iter(node)) + assert item[TextStreamingDecoder.DATA_KEY] == "Hello 世界" + + finally: + if os.path.exists(utf8_file): + os.remove(utf8_file) + os.rmdir(temp_dir) + + +def test_text_stream_cleanup(): + """Test proper file handle cleanup.""" + temp_dir, file_paths = create_test_files() + try: + source_node = MockSourceNode(file_paths) + node = TextStreamingDecoder(source_node) + + # Read partial file + next(iter(node)) + + # Force cleanup + del node + + # Should be able to delete files (no open handles) + for path in file_paths: + os.remove(path) + + finally: + for path in file_paths: + if os.path.exists(path): + os.remove(path) + os.rmdir(temp_dir) + + +@patch("smart_open.open") +def test_s3_basic_read(mock_smart_open): + """Test basic S3 file reading with mocked smart_open.""" + # Mock smart_open for S3 + mock_file = MagicMock() + mock_file.readline.side_effect = ['{"id": 1, "text": "Hello from S3"}\n', ""] + mock_smart_open.return_value.__enter__.return_value = mock_file + + file_paths = ["s3://test-bucket/test_file1.jsonl"] + source_node = MockSourceNode(file_paths, {"source": "s3"}) + node = TextStreamingDecoder(source_node) + + # Read first line + item = next(iter(node)) + + # Should contain content + assert TextStreamingDecoder.DATA_KEY in item + assert item[TextStreamingDecoder.DATA_KEY] == '{"id": 1, "text": "Hello from S3"}' + + # Check metadata + assert TextStreamingDecoder.METADATA_KEY in item + assert "file_path" in item[TextStreamingDecoder.METADATA_KEY] + assert item[TextStreamingDecoder.METADATA_KEY]["file_path"] == "s3://test-bucket/test_file1.jsonl" + assert item[TextStreamingDecoder.METADATA_KEY]["source"] == "s3" + + +@patch("smart_open.open") +def test_compression_handling(mock_smart_open): + """Test compressed file handling.""" + # Mock smart_open for compressed file + mock_file = MagicMock() + mock_file.readline.side_effect = ["decompressed_line1\n", "decompressed_line2\n", ""] + mock_smart_open.return_value.__enter__.return_value = mock_file + + file_paths = ["s3://bucket/compressed.txt.gz"] + source_node = MockSourceNode(file_paths, {"source": 
"s3"}) + node = TextStreamingDecoder(source_node, transport_params={"compression": ".gz"}) + + # Read lines + lines = [item[TextStreamingDecoder.DATA_KEY] for item in node] + assert lines == ["decompressed_line1", "decompressed_line2"] + + # Verify smart_open was called with compression parameters + mock_smart_open.assert_called_with( + "s3://bucket/compressed.txt.gz", "r", encoding="utf-8", transport_params={"compression": ".gz"} + ) + + +def test_error_handling(): + """Test error handling for invalid files.""" + temp_dir = tempfile.mkdtemp() + try: + # Create a file that exists + valid_path = os.path.join(temp_dir, "valid.txt") + with open(valid_path, "w") as f: + f.write("valid content\n") + + # Define a path that doesn't exist + invalid_path = os.path.join(temp_dir, "nonexistent.txt") + + # Node should skip invalid file and read valid one + source_node = MockSourceNode([invalid_path, valid_path]) + node = TextStreamingDecoder(source_node) + item = next(iter(node)) + + # Should get content from valid file + assert item[TextStreamingDecoder.DATA_KEY] == "valid content" + + finally: + if os.path.exists(valid_path): + os.remove(valid_path) + os.rmdir(temp_dir) + + +def test_text_stream_recursive_behavior(): + """Test TextStreamingDecoder handles file transitions without recursion issues.""" + temp_dir = tempfile.mkdtemp() + try: + # Create multiple files with known content + file1_path = os.path.join(temp_dir, "test1.txt") + with open(file1_path, "w") as f: + f.write("file1_line1\nfile1_line2\n") + + file2_path = os.path.join(temp_dir, "test2.txt") + with open(file2_path, "w") as f: + f.write("file2_line1\nfile2_line2\n") + + # Create an empty file to test empty file handling + empty_file_path = os.path.join(temp_dir, "empty.txt") + with open(empty_file_path, "w") as f: + pass + + # Create a file with an error that will be skipped + error_file_path = os.path.join(temp_dir, "error.txt") + # Don't actually create this file, so it will cause an error + + source_node = MockSourceNode([file1_path, empty_file_path, error_file_path, file2_path]) + node = TextStreamingDecoder(source_node) + + # Read all lines + lines = [] + for item in node: + lines.append(item[TextStreamingDecoder.DATA_KEY]) + # Also check that metadata is correct + assert TextStreamingDecoder.METADATA_KEY in item + assert "file_path" in item[TextStreamingDecoder.METADATA_KEY] + assert "item_idx" in item[TextStreamingDecoder.METADATA_KEY] + + # Should have 4 lines total (2 from file1, 0 from empty, 0 from error, 2 from file2) + assert lines == ["file1_line1", "file1_line2", "file2_line1", "file2_line2"] + + # Check that each line is only returned once + # Reset the node + new_source = MockSourceNode([file1_path, empty_file_path, error_file_path, file2_path]) + node = TextStreamingDecoder(new_source) + + # Read lines again and check for duplicates + seen_lines = set() + for item in node: + line = item[TextStreamingDecoder.DATA_KEY] + file_path = item[TextStreamingDecoder.METADATA_KEY]["file_path"] + line_idx = item[TextStreamingDecoder.METADATA_KEY]["item_idx"] + + # Create a unique identifier for this line + line_id = (line, file_path, line_idx) + + # Check that we haven't seen this line before + assert line_id not in seen_lines, f"Duplicate line: {line_id}" + seen_lines.add(line_id) + + finally: + # Clean up + for path in [file1_path, file2_path, empty_file_path]: + if os.path.exists(path): + os.remove(path) + os.rmdir(temp_dir) + + +@patch("smart_open.open") +def test_azure_gcs_support(mock_smart_open): + """Test Azure and GCS 
support via smart_open.""" + # Test Azure + mock_file = MagicMock() + mock_file.readline.side_effect = ["azure_content\n", ""] + mock_smart_open.return_value.__enter__.return_value = mock_file + + azure_paths = ["abfs://container@account.dfs.core.windows.net/file.txt"] + source_node = MockSourceNode(azure_paths, {"source": "abfs"}) + node = TextStreamingDecoder(source_node, transport_params={"anon": False}) + + item = next(iter(node)) + assert item[TextStreamingDecoder.DATA_KEY] == "azure_content" + assert item[TextStreamingDecoder.METADATA_KEY]["source"] == "abfs" + + # Test GCS + mock_file.readline.side_effect = ["gcs_content\n", ""] + gcs_paths = ["gs://my-bucket/file.txt"] + source_node = MockSourceNode(gcs_paths, {"source": "gs"}) + node = TextStreamingDecoder(source_node, transport_params={"client": "mock_gcs_client"}) + + item = next(iter(node)) + assert item[TextStreamingDecoder.DATA_KEY] == "gcs_content" + assert item[TextStreamingDecoder.METADATA_KEY]["source"] == "gs" diff --git a/test/requirements.txt b/test/requirements.txt index bbd7270d9..44dcee3c6 100644 --- a/test/requirements.txt +++ b/test/requirements.txt @@ -8,3 +8,4 @@ adlfs awscli>=1.27.66 psutil parameterized +smart_open diff --git a/torchdata/nodes/io/__init__.py b/torchdata/nodes/io/__init__.py new file mode 100644 index 000000000..17642e343 --- /dev/null +++ b/torchdata/nodes/io/__init__.py @@ -0,0 +1,3 @@ +from .text_streaming_decoder import TextStreamingDecoder + +__all__ = ["TextStreamingDecoder"] diff --git a/torchdata/nodes/io/text_streaming_decoder.py b/torchdata/nodes/io/text_streaming_decoder.py new file mode 100644 index 000000000..aaad082b2 --- /dev/null +++ b/torchdata/nodes/io/text_streaming_decoder.py @@ -0,0 +1,233 @@ +import logging +from typing import Any, Dict, Optional, Union + +from smart_open import open +from torchdata.nodes import BaseNode + +logger = logging.getLogger(__name__) + + +class TextStreamingDecoder(BaseNode[Dict]): + """Node that streams text files line by line from any source. + + This node combines functionality of file reading and line-by-line processing, + supporting both local and remote (S3, GCS, Azure, HTTP, etc.) files via smart_open. + + Features: + - Streams files line-by-line (memory efficient) + - Supports any filesystem that smart_open supports (local, S3, GCS, Azure, HTTP, etc.) + - Handles compressed files (.gz, .bz2) transparently + - Maintains state for checkpointing and resumption + - Preserves metadata from source nodes + + Input format: + { + "data": "path/to/file.txt", # File path (local) or URI (s3://, etc.) + "metadata": {...} # Optional metadata + } + or simply a string with the file path/URI + + Output format: + { + "data": "line content", # Single line of text + "metadata": { + "file_path": "path/to/file.txt", + "item_idx": 0, # 0-based line index + ... # Additional metadata from input + } + } + + Examples: + >>> # Stream from local file + >>> node = TextStreamingDecoder(source_node) + >>> + >>> # Stream from S3 with custom client + >>> node = TextStreamingDecoder( + ... source_node, + ... transport_params={'client': boto3.client('s3')} + ... ) + >>> + >>> # Stream compressed files + >>> node = TextStreamingDecoder( + ... source_node, + ... transport_params={'compression': '.gz'} + ... 
) + """ + + SOURCE_KEY = "source" + DATA_KEY = "data" + METADATA_KEY = "metadata" + CURRENT_FILE_KEY = "current_file" + + def __init__( + self, + source_node: BaseNode[Union[str, Dict]], + mode: str = "r", + encoding: Optional[str] = "utf-8", + transport_params: Optional[Dict] = None, + ): + """Initialize the TextStreamingDecoder. + + Args: + source_node: Source node that yields dicts with file paths + mode: File open mode ('r' for text, 'rb' for binary) + encoding: Text encoding (None for binary mode) + transport_params: Parameters for smart_open transport layer + For S3: + {'client': boto3.client('s3')} # Use specific client + For compression: + {'compression': '.gz'} # Force gzip compression + {'compression': '.bz2'} # Force bz2 compression + {'compression': 'disable'} # Disable compression + """ + super().__init__() + self.source = source_node + self.mode = mode + self.encoding = encoding + self.transport_params = transport_params or {} + self._current_file = None + self._current_line = 0 + self._file_handle = None + self._source_metadata = {} + + def reset(self, initial_state: Optional[Dict[str, Any]] = None): + """Reset must fully initialize the node's state. + + Args: + initial_state: Optional state dictionary for resumption + """ + super().reset(initial_state) + + # Close any open file + if self._file_handle is not None: + self._file_handle.close() + self._file_handle = None + + if initial_state is None: + # Full reset + self.source.reset(None) + self._current_file = None + self._current_line = 0 + self._source_metadata = {} + else: + # Restore source state + self.source.reset(initial_state[self.SOURCE_KEY]) + self._current_file = initial_state[self.CURRENT_FILE_KEY] + self._current_line = initial_state.get("current_line", 0) + self._source_metadata = initial_state.get(self.METADATA_KEY, {}) + + # If we have a file to resume, open and seek to position + if self._current_file is not None: + self._file_handle = open( + self._current_file, self.mode, encoding=self.encoding, transport_params=self.transport_params + ) + # Skip lines to resume position + for _ in range(self._current_line): + next(self._file_handle) + + def __del__(self): + """Ensure file is closed on deletion.""" + if self._file_handle is not None: + try: + self._file_handle.close() + except Exception: + pass # Ignore errors during cleanup + + def _get_next_file(self) -> bool: + """Get the next file and open it for reading. + + Returns: + bool: True if a new file was successfully opened, False otherwise. + """ + try: + # Get next file from source + file_data = self.source.next() + + # Extract file path from data + if isinstance(file_data, dict) and self.DATA_KEY in file_data: + self._current_file = file_data[self.DATA_KEY] + # Copy metadata from source + if self.METADATA_KEY in file_data: + self._source_metadata = file_data[self.METADATA_KEY] + else: + self._current_file = file_data + self._source_metadata = {} + + try: + # Open the file + self._file_handle = open( + self._current_file, self.mode, encoding=self.encoding, transport_params=self.transport_params + ) + self._current_line = 0 + return True + except Exception as e: + logger.error(f"Error opening {self._current_file}: {e}") + self._file_handle = None + return False # Failed to open file + + except StopIteration: + # No more files + raise + + def _get_next_line(self) -> Dict: + """Read the next line from the current file. + + Returns: + Dict: Dictionary with the line data and metadata. 
+ + Raises: + StopIteration: If end of file is reached and no more files are available. + """ + try: + line = self._file_handle.readline() + + # EOF or empty line at end of file + if not line: + self._file_handle.close() + self._file_handle = None + return None # Signal end of file + + # Create output with metadata + metadata = {"file_path": self._current_file, "item_idx": self._current_line} + + # Include metadata from source + if self._source_metadata: + metadata.update(self._source_metadata) + + self._current_line += 1 + + return {self.DATA_KEY: line.rstrip("\n"), self.METADATA_KEY: metadata} + + except Exception as e: + logger.error(f"Error reading from {self._current_file}: {e}") + if self._file_handle: + self._file_handle.close() + self._file_handle = None + return None # Signal error + + def next(self) -> Dict: + """Get the next line from current file or next available file.""" + # Loop until we get a valid line or run out of files + while True: + # If we don't have a file handle, get a new one + while self._file_handle is None: + if not self._get_next_file(): + continue # Try the next file if this one failed + + # Try to get the next line + line_data = self._get_next_line() + + # If we reached the end of the file, try the next one + if line_data is None: + continue + + # We got a valid line, return it + return line_data + + def get_state(self) -> Dict[str, Any]: + """Get current state for checkpointing.""" + return { + self.SOURCE_KEY: self.source.state_dict(), + self.CURRENT_FILE_KEY: self._current_file, + "current_line": self._current_line, + } From fe9d554c409072f19dffd68fe3a744eeeaac83e1 Mon Sep 17 00:00:00 2001 From: keunwoochoi Date: Sun, 31 Aug 2025 16:53:06 -0400 Subject: [PATCH 2/9] CURRENT_LINE_KEY is added --- torchdata/nodes/io/text_streaming_decoder.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/torchdata/nodes/io/text_streaming_decoder.py b/torchdata/nodes/io/text_streaming_decoder.py index aaad082b2..dac6e73ef 100644 --- a/torchdata/nodes/io/text_streaming_decoder.py +++ b/torchdata/nodes/io/text_streaming_decoder.py @@ -58,6 +58,7 @@ class TextStreamingDecoder(BaseNode[Dict]): DATA_KEY = "data" METADATA_KEY = "metadata" CURRENT_FILE_KEY = "current_file" + CURRENT_LINE_KEY = "current_line" def __init__( self, @@ -113,7 +114,7 @@ def reset(self, initial_state: Optional[Dict[str, Any]] = None): # Restore source state self.source.reset(initial_state[self.SOURCE_KEY]) self._current_file = initial_state[self.CURRENT_FILE_KEY] - self._current_line = initial_state.get("current_line", 0) + self._current_line = initial_state.get(self.CURRENT_LINE_KEY, 0) self._source_metadata = initial_state.get(self.METADATA_KEY, {}) # If we have a file to resume, open and seek to position @@ -229,5 +230,5 @@ def get_state(self) -> Dict[str, Any]: return { self.SOURCE_KEY: self.source.state_dict(), self.CURRENT_FILE_KEY: self._current_file, - "current_line": self._current_line, + self.CURRENT_LINE_KEY: self._current_line, } From b778d71b62a20dff706aba93ba38c8c73ca560a2 Mon Sep 17 00:00:00 2001 From: keunwoochoi Date: Sun, 31 Aug 2025 16:58:17 -0400 Subject: [PATCH 3/9] add shutdown. 
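
Besides the shutdown() hook, this switches the import style from
"from smart_open import open" to calling smart_open.open(...). The
module-level call is what lets @patch("smart_open.open") in the tests
intercept opens, since the name is now resolved on the module at call
time instead of being bound into this module at import time.

shutdown() gives pipeline teardown a deterministic way to release the
open handle instead of relying on __del__. A minimal usage sketch
(source_node and process are placeholder names, not part of this
change):

    node = TextStreamingDecoder(source_node)
    try:
        for item in node:
            process(item[TextStreamingDecoder.DATA_KEY])
    finally:
        node.shutdown()  # closes any still-open file handle
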
--- torchdata/nodes/io/text_streaming_decoder.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/torchdata/nodes/io/text_streaming_decoder.py b/torchdata/nodes/io/text_streaming_decoder.py index dac6e73ef..532c5860b 100644 --- a/torchdata/nodes/io/text_streaming_decoder.py +++ b/torchdata/nodes/io/text_streaming_decoder.py @@ -1,7 +1,7 @@ import logging from typing import Any, Dict, Optional, Union -from smart_open import open +import smart_open from torchdata.nodes import BaseNode logger = logging.getLogger(__name__) @@ -119,7 +119,7 @@ def reset(self, initial_state: Optional[Dict[str, Any]] = None): # If we have a file to resume, open and seek to position if self._current_file is not None: - self._file_handle = open( + self._file_handle = smart_open.open( self._current_file, self.mode, encoding=self.encoding, transport_params=self.transport_params ) # Skip lines to resume position @@ -156,7 +156,7 @@ def _get_next_file(self) -> bool: try: # Open the file - self._file_handle = open( + self._file_handle = smart_open.open( self._current_file, self.mode, encoding=self.encoding, transport_params=self.transport_params ) self._current_line = 0 @@ -232,3 +232,9 @@ def get_state(self) -> Dict[str, Any]: self.CURRENT_FILE_KEY: self._current_file, self.CURRENT_LINE_KEY: self._current_line, } + + def shutdown(self): + """Shutdown the node.""" + if self._file_handle is not None: + self._file_handle.close() + self._file_handle = None From 626de203c388d06dc2e68845f480dae95df7a659 Mon Sep 17 00:00:00 2001 From: keunwoochoi Date: Sun, 31 Aug 2025 17:01:18 -0400 Subject: [PATCH 4/9] add retry logic --- test/nodes/io/test_text_streaming_decoder.py | 181 ++++++++++++++++++- torchdata/nodes/io/text_streaming_decoder.py | 100 ++++++++-- 2 files changed, 261 insertions(+), 20 deletions(-) diff --git a/test/nodes/io/test_text_streaming_decoder.py b/test/nodes/io/test_text_streaming_decoder.py index db8079522..9a42e8369 100644 --- a/test/nodes/io/test_text_streaming_decoder.py +++ b/test/nodes/io/test_text_streaming_decoder.py @@ -28,6 +28,11 @@ 5. Error handling - Invalid file paths - Recovery from errors + +6. Retry logic + - Retry on file opening errors + - Fibonacci backoff + - Max retries configuration """ import os @@ -38,7 +43,7 @@ import pytest from torchdata.nodes import BaseNode -from torchdata.nodes.io.text_streaming_decoder import TextStreamingDecoder +from torchdata.nodes.io.text_streaming_decoder import _fibonacci_backoff, TextStreamingDecoder class MockSourceNode(BaseNode[Dict]): @@ -403,3 +408,177 @@ def test_azure_gcs_support(mock_smart_open): item = next(iter(node)) assert item[TextStreamingDecoder.DATA_KEY] == "gcs_content" assert item[TextStreamingDecoder.METADATA_KEY]["source"] == "gs" + + +def test_fibonacci_backoff(): + """Test Fibonacci backoff calculation.""" + # Test Fibonacci sequence: 1, 1, 2, 3, 5, 8, 13, 21, ... 
+ expected_delays = [1.0, 1.0, 2.0, 3.0, 5.0, 8.0, 13.0, 21.0] + + for attempt, expected_delay in enumerate(expected_delays, 1): + actual_delay = _fibonacci_backoff(attempt) + assert actual_delay == expected_delay, f"Attempt {attempt}: expected {expected_delay}, got {actual_delay}" + + # Test with custom base delay + assert _fibonacci_backoff(1, base_delay=2.0) == 2.0 + assert _fibonacci_backoff(2, base_delay=2.0) == 2.0 + assert _fibonacci_backoff(3, base_delay=2.0) == 4.0 + + # Test edge cases + assert _fibonacci_backoff(0) == 0.0 + assert _fibonacci_backoff(-1) == 0.0 + + +@patch("smart_open.open") +@patch("time.sleep") +def test_retry_logic_success_after_failure(mock_sleep, mock_smart_open): + """Test retry logic when it succeeds after initial failures.""" + # Mock smart_open to fail twice then succeed + mock_file = MagicMock() + mock_file.readline.side_effect = ["success_line\n", ""] + + mock_smart_open.side_effect = [ + Exception("Connection timeout"), # First attempt fails + Exception("Network error"), # Second attempt fails + mock_file, # Third attempt succeeds + ] + + file_paths = ["s3://bucket/test.txt"] + source_node = MockSourceNode(file_paths) + node = TextStreamingDecoder(source_node, max_retries=3) + + # Should succeed on third attempt + content = next(iter(node)) + assert content[TextStreamingDecoder.DATA_KEY] == "success_line" + + # Verify sleep was called twice with Fibonacci delays + assert mock_sleep.call_count == 2 + mock_sleep.assert_any_call(1.0) # First retry delay + mock_sleep.assert_any_call(1.0) # Second retry delay + + +@patch("smart_open.open") +@patch("time.sleep") +def test_retry_logic_max_retries_exceeded(mock_sleep, mock_smart_open): + """Test retry logic when max retries are exceeded.""" + # Mock smart_open to always fail + mock_smart_open.side_effect = Exception("Connection timeout") + + file_paths = ["s3://bucket/test.txt"] + source_node = MockSourceNode(file_paths) + node = TextStreamingDecoder(source_node, max_retries=2) + + # Should skip the file and try the next one (if any) + # Since we only have one file and it fails, we should get StopIteration + with pytest.raises(StopIteration): + next(iter(node)) + + # Verify sleep was called twice + assert mock_sleep.call_count == 2 + mock_sleep.assert_any_call(1.0) # First retry delay + mock_sleep.assert_any_call(1.0) # Second retry delay + + +@patch("smart_open.open") +@patch("time.sleep") +def test_retry_logic_zero_retries(mock_sleep, mock_smart_open): + """Test retry logic with zero retries (should fail immediately).""" + # Mock smart_open to fail + mock_smart_open.side_effect = Exception("Connection timeout") + + file_paths = ["s3://bucket/test.txt"] + source_node = MockSourceNode(file_paths) + node = TextStreamingDecoder(source_node, max_retries=0) + + # Should fail immediately without retrying + with pytest.raises(StopIteration): + next(iter(node)) + + # Verify sleep was never called + mock_sleep.assert_not_called() + + +@patch("smart_open.open") +@patch("time.sleep") +def test_retry_logic_state_restoration(mock_sleep, mock_smart_open): + """Test retry logic during state restoration.""" + # Mock smart_open to fail twice then succeed during state restoration + mock_file = MagicMock() + mock_file.readline.side_effect = ["resumed_line\n", ""] + + mock_smart_open.side_effect = [ + Exception("Connection timeout"), # First attempt fails + Exception("Network error"), # Second attempt fails + mock_file, # Third attempt succeeds + ] + + temp_dir, file_paths = create_test_files() + try: + source_node = 
MockSourceNode(file_paths) + node = TextStreamingDecoder(source_node, max_retries=3) + + # Read first line and store state + first_item = next(iter(node)) + state = node.get_state() + + # Create new node and restore state (this will trigger retry logic) + new_source = MockSourceNode(file_paths) + new_node = TextStreamingDecoder(new_source, max_retries=3) + new_node.reset(state) + + # Read next line - should succeed after retries + second_item = next(iter(new_node)) + assert second_item[TextStreamingDecoder.DATA_KEY] != first_item[TextStreamingDecoder.DATA_KEY] + + # Verify sleep was called twice during state restoration + assert mock_sleep.call_count == 2 + mock_sleep.assert_any_call(1.0) # First retry delay + mock_sleep.assert_any_call(1.0) # Second retry delay + + finally: + for path in file_paths: + if os.path.exists(path): + os.remove(path) + os.rmdir(temp_dir) + + +def test_text_streaming_decoder_custom_max_retries(): + """Test TextStreamingDecoder with custom max_retries parameter.""" + file_paths = ["test.txt"] + source_node = MockSourceNode(file_paths) + + # Test default max_retries + node_default = TextStreamingDecoder(source_node) + assert node_default.max_retries == 3 + + # Test custom max_retries + node_custom = TextStreamingDecoder(source_node, max_retries=5) + assert node_custom.max_retries == 5 + + # Test zero retries + node_zero = TextStreamingDecoder(source_node, max_retries=0) + assert node_zero.max_retries == 0 + + +@patch("smart_open.open") +@patch("time.sleep") +def test_retry_logic_break_on_success(mock_sleep, mock_smart_open): + """Test that the retry loop breaks immediately on successful file opening.""" + # Mock smart_open to succeed on first attempt + mock_file = MagicMock() + mock_file.readline.side_effect = ["success_line\n", ""] + mock_smart_open.return_value = mock_file + + file_paths = ["s3://bucket/test.txt"] + source_node = MockSourceNode(file_paths) + node = TextStreamingDecoder(source_node, max_retries=3) + + # Should succeed immediately + content = next(iter(node)) + assert content[TextStreamingDecoder.DATA_KEY] == "success_line" + + # Verify sleep was never called (no retries needed) + mock_sleep.assert_not_called() + + # Verify smart_open was called exactly once + assert mock_smart_open.call_count == 1 diff --git a/torchdata/nodes/io/text_streaming_decoder.py b/torchdata/nodes/io/text_streaming_decoder.py index 532c5860b..b1f13752e 100644 --- a/torchdata/nodes/io/text_streaming_decoder.py +++ b/torchdata/nodes/io/text_streaming_decoder.py @@ -1,4 +1,5 @@ import logging +import time from typing import Any, Dict, Optional, Union import smart_open @@ -7,6 +8,27 @@ logger = logging.getLogger(__name__) +def _fibonacci_backoff(attempt: int, base_delay: float = 1.0) -> float: + """Calculate Fibonacci backoff delay for retry attempts. + + Args: + attempt: Current attempt number (1-based) + base_delay: Base delay in seconds + + Returns: + float: Delay in seconds before next retry + """ + if attempt <= 0: + return 0.0 + + # Fibonacci sequence: 1, 1, 2, 3, 5, 8, 13, 21, ... + fib_sequence = [1, 1] + for i in range(2, attempt + 1): + fib_sequence.append(fib_sequence[i - 1] + fib_sequence[i - 2]) + + return base_delay * fib_sequence[attempt - 1] + + class TextStreamingDecoder(BaseNode[Dict]): """Node that streams text files line by line from any source. 
@@ -19,6 +41,7 @@ class TextStreamingDecoder(BaseNode[Dict]): - Handles compressed files (.gz, .bz2) transparently - Maintains state for checkpointing and resumption - Preserves metadata from source nodes + - Automatic retry with Fibonacci backoff for file opening errors Input format: { @@ -41,10 +64,11 @@ class TextStreamingDecoder(BaseNode[Dict]): >>> # Stream from local file >>> node = TextStreamingDecoder(source_node) >>> - >>> # Stream from S3 with custom client + >>> # Stream from S3 with custom client and retry logic >>> node = TextStreamingDecoder( ... source_node, - ... transport_params={'client': boto3.client('s3')} + ... transport_params={'client': boto3.client('s3')}, + ... max_retries=5 ... ) >>> >>> # Stream compressed files @@ -66,6 +90,7 @@ def __init__( mode: str = "r", encoding: Optional[str] = "utf-8", transport_params: Optional[Dict] = None, + max_retries: int = 3, ): """Initialize the TextStreamingDecoder. @@ -80,12 +105,14 @@ def __init__( {'compression': '.gz'} # Force gzip compression {'compression': '.bz2'} # Force bz2 compression {'compression': 'disable'} # Disable compression + max_retries: Maximum number of retry attempts for file opening errors (default: 3) """ super().__init__() self.source = source_node self.mode = mode self.encoding = encoding self.transport_params = transport_params or {} + self.max_retries = max_retries self._current_file = None self._current_line = 0 self._file_handle = None @@ -119,12 +146,34 @@ def reset(self, initial_state: Optional[Dict[str, Any]] = None): # If we have a file to resume, open and seek to position if self._current_file is not None: - self._file_handle = smart_open.open( - self._current_file, self.mode, encoding=self.encoding, transport_params=self.transport_params - ) - # Skip lines to resume position - for _ in range(self._current_line): - next(self._file_handle) + # Retry logic for file opening during state restoration + for attempt in range(1, self.max_retries + 1): + try: + self._file_handle = smart_open.open( + self._current_file, + self.mode, + encoding=self.encoding, + transport_params=self.transport_params, + ) + # Skip lines to resume position + for _ in range(self._current_line): + next(self._file_handle) + break # Successfully opened and positioned + + except Exception as e: + if attempt < self.max_retries: + delay = _fibonacci_backoff(attempt) + logger.warning( + f"Error opening {self._current_file} during state restoration (attempt {attempt}/{self.max_retries}): {e}. Retrying in {delay:.2f}s..." + ) + time.sleep(delay) + else: + # Max retries reached, log error and continue without file handle + logger.error( + f"Failed to open {self._current_file} during state restoration after {self.max_retries} attempts. 
Last error: {e}" + ) + self._file_handle = None + break def __del__(self): """Ensure file is closed on deletion.""" @@ -154,17 +203,30 @@ def _get_next_file(self) -> bool: self._current_file = file_data self._source_metadata = {} - try: - # Open the file - self._file_handle = smart_open.open( - self._current_file, self.mode, encoding=self.encoding, transport_params=self.transport_params - ) - self._current_line = 0 - return True - except Exception as e: - logger.error(f"Error opening {self._current_file}: {e}") - self._file_handle = None - return False # Failed to open file + # Retry logic for file opening + for attempt in range(1, self.max_retries + 1): + try: + # Try to open the file + self._file_handle = smart_open.open( + self._current_file, self.mode, encoding=self.encoding, transport_params=self.transport_params + ) + self._current_line = 0 + return True + + except Exception as e: + if attempt < self.max_retries: + delay = _fibonacci_backoff(attempt) + logger.warning( + f"Error opening {self._current_file} (attempt {attempt}/{self.max_retries}): {e}. Retrying in {delay:.2f}s..." + ) + time.sleep(delay) + else: + # Max retries reached + logger.error( + f"Failed to open {self._current_file} after {self.max_retries} attempts. Last error: {e}" + ) + self._file_handle = None + return False # Failed to open file except StopIteration: # No more files From 2575995776e0e886fe8fecf523dc3735af0aefc6 Mon Sep 17 00:00:00 2001 From: keunwoochoi Date: Mon, 8 Sep 2025 10:40:38 -0400 Subject: [PATCH 5/9] address 8 mypy errors --- torchdata/nodes/io/text_streaming_decoder.py | 42 ++++++++++++++------ 1 file changed, 30 insertions(+), 12 deletions(-) diff --git a/torchdata/nodes/io/text_streaming_decoder.py b/torchdata/nodes/io/text_streaming_decoder.py index b1f13752e..596e832e7 100644 --- a/torchdata/nodes/io/text_streaming_decoder.py +++ b/torchdata/nodes/io/text_streaming_decoder.py @@ -2,7 +2,7 @@ import time from typing import Any, Dict, Optional, Union -import smart_open +import smart_open # type: ignore[import-untyped] from torchdata.nodes import BaseNode logger = logging.getLogger(__name__) @@ -89,7 +89,7 @@ def __init__( source_node: BaseNode[Union[str, Dict]], mode: str = "r", encoding: Optional[str] = "utf-8", - transport_params: Optional[Dict] = None, + transport_params: Optional[Dict[str, Any]] = None, max_retries: int = 3, ): """Initialize the TextStreamingDecoder. @@ -105,7 +105,7 @@ def __init__( {'compression': '.gz'} # Force gzip compression {'compression': '.bz2'} # Force bz2 compression {'compression': 'disable'} # Disable compression - max_retries: Maximum number of retry attempts for file opening errors (default: 3) + max_retries: Maximum number of retry attempts for any errors (default: 3) """ super().__init__() self.source = source_node @@ -113,10 +113,10 @@ def __init__( self.encoding = encoding self.transport_params = transport_params or {} self.max_retries = max_retries - self._current_file = None + self._current_file: Optional[str] = None self._current_line = 0 - self._file_handle = None - self._source_metadata = {} + self._file_handle: Optional[Any] = None + self._source_metadata: Dict[str, Any] = {} def reset(self, initial_state: Optional[Dict[str, Any]] = None): """Reset must fully initialize the node's state. 
@@ -195,13 +195,21 @@ def _get_next_file(self) -> bool: # Extract file path from data if isinstance(file_data, dict) and self.DATA_KEY in file_data: - self._current_file = file_data[self.DATA_KEY] + file_path = file_data[self.DATA_KEY] + if isinstance(file_path, str): + self._current_file = file_path + else: + logger.error(f"Invalid file path type: {type(file_path)}, expected str") + return False # Copy metadata from source if self.METADATA_KEY in file_data: self._source_metadata = file_data[self.METADATA_KEY] - else: + elif isinstance(file_data, str): self._current_file = file_data self._source_metadata = {} + else: + logger.error(f"Invalid file data type: {type(file_data)}") + return False # Retry logic for file opening for attempt in range(1, self.max_retries + 1): @@ -228,19 +236,29 @@ def _get_next_file(self) -> bool: self._file_handle = None return False # Failed to open file + # This should never be reached, but mypy needs it + return False + except StopIteration: - # No more files + # No more files - this should propagate up to stop iteration raise + except Exception as e: + # Any other unexpected error + logger.error(f"Unexpected error in _get_next_file: {e}") + return False - def _get_next_line(self) -> Dict: + def _get_next_line(self) -> Optional[Dict[str, Any]]: """Read the next line from the current file. Returns: - Dict: Dictionary with the line data and metadata. + Optional[Dict[str, Any]]: Dictionary with the line data and metadata, or None if end of file or error. Raises: StopIteration: If end of file is reached and no more files are available. """ + if self._file_handle is None: + return None + try: line = self._file_handle.readline() @@ -268,7 +286,7 @@ def _get_next_line(self) -> Dict: self._file_handle = None return None # Signal error - def next(self) -> Dict: + def next(self) -> Dict[str, Any]: """Get the next line from current file or next available file.""" # Loop until we get a valid line or run out of files while True: From 2e497336b16cc1839b1ec40c98af028edca1ff65 Mon Sep 17 00:00:00 2001 From: keunwoochoi Date: Mon, 8 Sep 2025 20:07:17 -0400 Subject: [PATCH 6/9] wip fixing --- torchdata/nodes/io/text_streaming_decoder.py | 47 +++++++++++++------- 1 file changed, 30 insertions(+), 17 deletions(-) diff --git a/torchdata/nodes/io/text_streaming_decoder.py b/torchdata/nodes/io/text_streaming_decoder.py index 596e832e7..5953e5422 100644 --- a/torchdata/nodes/io/text_streaming_decoder.py +++ b/torchdata/nodes/io/text_streaming_decoder.py @@ -147,30 +147,38 @@ def reset(self, initial_state: Optional[Dict[str, Any]] = None): # If we have a file to resume, open and seek to position if self._current_file is not None: # Retry logic for file opening during state restoration - for attempt in range(1, self.max_retries + 1): + for attempt in range(0, self.max_retries + 1): try: - self._file_handle = smart_open.open( + cm = smart_open.open( self._current_file, self.mode, encoding=self.encoding, transport_params=self.transport_params, ) - # Skip lines to resume position + # Prefer direct streaming handle when available + if hasattr(cm, "readline"): + self._file_handle = cm + elif hasattr(cm, "__enter__"): + self._file_handle = cm.__enter__() + else: + self._file_handle = cm + # Skip lines to resume position using streaming readline for _ in range(self._current_line): - next(self._file_handle) + _ = self._file_handle.readline() break # Successfully opened and positioned except Exception as e: - if attempt < self.max_retries: - delay = _fibonacci_backoff(attempt) + is_final = 
attempt >= self.max_retries + if not is_final: + delay = _fibonacci_backoff(attempt + 1) logger.warning( - f"Error opening {self._current_file} during state restoration (attempt {attempt}/{self.max_retries}): {e}. Retrying in {delay:.2f}s..." + f"Error opening {self._current_file} during state restoration (attempt {attempt + 1}/{self.max_retries + 1}): {e}. Retrying in {delay:.2f}s..." ) time.sleep(delay) else: # Max retries reached, log error and continue without file handle logger.error( - f"Failed to open {self._current_file} during state restoration after {self.max_retries} attempts. Last error: {e}" + f"Failed to open {self._current_file} during state restoration after {self.max_retries + 1} attempts. Last error: {e}" ) self._file_handle = None break @@ -212,33 +220,38 @@ def _get_next_file(self) -> bool: return False # Retry logic for file opening - for attempt in range(1, self.max_retries + 1): + for attempt in range(0, self.max_retries + 1): try: # Try to open the file - self._file_handle = smart_open.open( + cm = smart_open.open( self._current_file, self.mode, encoding=self.encoding, transport_params=self.transport_params ) + # Prefer direct streaming handle when available + if hasattr(cm, "readline"): + self._file_handle = cm + elif hasattr(cm, "__enter__"): + self._file_handle = cm.__enter__() + else: + self._file_handle = cm self._current_line = 0 return True except Exception as e: - if attempt < self.max_retries: - delay = _fibonacci_backoff(attempt) + is_final = attempt >= self.max_retries + if not is_final: + delay = _fibonacci_backoff(attempt + 1) logger.warning( - f"Error opening {self._current_file} (attempt {attempt}/{self.max_retries}): {e}. Retrying in {delay:.2f}s..." + f"Error opening {self._current_file} (attempt {attempt + 1}/{self.max_retries + 1}): {e}. Retrying in {delay:.2f}s..." ) time.sleep(delay) else: # Max retries reached logger.error( - f"Failed to open {self._current_file} after {self.max_retries} attempts. Last error: {e}" + f"Failed to open {self._current_file} after {self.max_retries + 1} attempts. Last error: {e}" ) self._file_handle = None return False # Failed to open file - # This should never be reached, but mypy needs it - return False - except StopIteration: # No more files - this should propagate up to stop iteration raise From cdd608900cb28dab05202dbbfe98e0872f23e684 Mon Sep 17 00:00:00 2001 From: keunwoochoi Date: Wed, 10 Sep 2025 14:26:12 -0400 Subject: [PATCH 7/9] ok test passed. 
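
smart_open.open() is now treated uniformly as a context manager: the
node enters it with __enter__() to obtain the file handle instead of
probing for a readline attribute. The tests mirror this by mocking
the context-manager protocol explicitly; the recurring setup is
(a sketch with illustrative line content):

    mock_file = MagicMock()
    mock_file.readline.side_effect = ["line\n", ""]

    mock_context = MagicMock()
    mock_context.__enter__.return_value = mock_file
    mock_context.__exit__.return_value = None
    mock_smart_open.return_value = mock_context
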
--- test/nodes/io/test_text_streaming_decoder.py | 99 +++++++++++++------- torchdata/nodes/io/text_streaming_decoder.py | 18 +--- 2 files changed, 68 insertions(+), 49 deletions(-) diff --git a/test/nodes/io/test_text_streaming_decoder.py b/test/nodes/io/test_text_streaming_decoder.py index 9a42e8369..b332d81ff 100644 --- a/test/nodes/io/test_text_streaming_decoder.py +++ b/test/nodes/io/test_text_streaming_decoder.py @@ -248,10 +248,15 @@ def test_text_stream_cleanup(): @patch("smart_open.open") def test_s3_basic_read(mock_smart_open): """Test basic S3 file reading with mocked smart_open.""" - # Mock smart_open for S3 + # Mock smart_open for S3 - set up context manager without readline attribute mock_file = MagicMock() mock_file.readline.side_effect = ['{"id": 1, "text": "Hello from S3"}\n', ""] - mock_smart_open.return_value.__enter__.return_value = mock_file + + # Set up mock context manager + mock_context = MagicMock() + mock_context.__enter__.return_value = mock_file + mock_context.__exit__.return_value = None + mock_smart_open.return_value = mock_context file_paths = ["s3://test-bucket/test_file1.jsonl"] source_node = MockSourceNode(file_paths, {"source": "s3"}) @@ -274,10 +279,14 @@ def test_s3_basic_read(mock_smart_open): @patch("smart_open.open") def test_compression_handling(mock_smart_open): """Test compressed file handling.""" - # Mock smart_open for compressed file + # Mock smart_open for compressed file - set up context manager without readline mock_file = MagicMock() mock_file.readline.side_effect = ["decompressed_line1\n", "decompressed_line2\n", ""] - mock_smart_open.return_value.__enter__.return_value = mock_file + + mock_context = MagicMock() + mock_context.__enter__.return_value = mock_file + mock_context.__exit__.return_value = None + mock_smart_open.return_value = mock_context file_paths = ["s3://bucket/compressed.txt.gz"] source_node = MockSourceNode(file_paths, {"source": "s3"}) @@ -386,10 +395,14 @@ def test_text_stream_recursive_behavior(): @patch("smart_open.open") def test_azure_gcs_support(mock_smart_open): """Test Azure and GCS support via smart_open.""" - # Test Azure + # Test Azure - set up context manager without readline mock_file = MagicMock() mock_file.readline.side_effect = ["azure_content\n", ""] - mock_smart_open.return_value.__enter__.return_value = mock_file + + mock_context = MagicMock() + mock_context.__enter__.return_value = mock_file + mock_context.__exit__.return_value = None + mock_smart_open.return_value = mock_context azure_paths = ["abfs://container@account.dfs.core.windows.net/file.txt"] source_node = MockSourceNode(azure_paths, {"source": "abfs"}) @@ -399,7 +412,7 @@ def test_azure_gcs_support(mock_smart_open): assert item[TextStreamingDecoder.DATA_KEY] == "azure_content" assert item[TextStreamingDecoder.METADATA_KEY]["source"] == "abfs" - # Test GCS + # Test GCS - reset mock file for new content mock_file.readline.side_effect = ["gcs_content\n", ""] gcs_paths = ["gs://my-bucket/file.txt"] source_node = MockSourceNode(gcs_paths, {"source": "gs"}) @@ -437,10 +450,17 @@ def test_retry_logic_success_after_failure(mock_sleep, mock_smart_open): mock_file = MagicMock() mock_file.readline.side_effect = ["success_line\n", ""] + # Set up successful context manager for third attempt + mock_context = MagicMock() + mock_context.__enter__.return_value = mock_file + mock_context.__exit__.return_value = None + # Explicitly delete readline from context manager to force __enter__ path + del mock_context.readline + mock_smart_open.side_effect = [ 
Exception("Connection timeout"), # First attempt fails Exception("Network error"), # Second attempt fails - mock_file, # Third attempt succeeds + mock_context, # Third attempt succeeds ] file_paths = ["s3://bucket/test.txt"] @@ -504,42 +524,46 @@ def test_retry_logic_state_restoration(mock_sleep, mock_smart_open): """Test retry logic during state restoration.""" # Mock smart_open to fail twice then succeed during state restoration mock_file = MagicMock() - mock_file.readline.side_effect = ["resumed_line\n", ""] + # First readline call for skipping to position, then actual content + mock_file.readline.side_effect = ["", "resumed_line\n", ""] + + # Set up successful context manager for third attempt + mock_context = MagicMock() + mock_context.__enter__.return_value = mock_file + mock_context.__exit__.return_value = None + # Explicitly delete readline from context manager to force __enter__ path + del mock_context.readline mock_smart_open.side_effect = [ Exception("Connection timeout"), # First attempt fails Exception("Network error"), # Second attempt fails - mock_file, # Third attempt succeeds + mock_context, # Third attempt succeeds ] - temp_dir, file_paths = create_test_files() - try: - source_node = MockSourceNode(file_paths) - node = TextStreamingDecoder(source_node, max_retries=3) + # Use mock file paths instead of real files to avoid conflicts + file_paths = ["mock://file1.txt"] - # Read first line and store state - first_item = next(iter(node)) - state = node.get_state() + # Create a mock state that simulates having read one line already + mock_source_state = {"idx": 1} + state = { + TextStreamingDecoder.SOURCE_KEY: mock_source_state, + TextStreamingDecoder.CURRENT_FILE_KEY: file_paths[0], + TextStreamingDecoder.CURRENT_LINE_KEY: 1, # Simulate having read one line + } - # Create new node and restore state (this will trigger retry logic) - new_source = MockSourceNode(file_paths) - new_node = TextStreamingDecoder(new_source, max_retries=3) - new_node.reset(state) + # Create new node and restore state (this will trigger retry logic) + new_source = MockSourceNode(file_paths) + new_node = TextStreamingDecoder(new_source, max_retries=3) + new_node.reset(state) - # Read next line - should succeed after retries - second_item = next(iter(new_node)) - assert second_item[TextStreamingDecoder.DATA_KEY] != first_item[TextStreamingDecoder.DATA_KEY] - - # Verify sleep was called twice during state restoration - assert mock_sleep.call_count == 2 - mock_sleep.assert_any_call(1.0) # First retry delay - mock_sleep.assert_any_call(1.0) # Second retry delay + # Read next line - should succeed after retries + second_item = next(iter(new_node)) + assert second_item[TextStreamingDecoder.DATA_KEY] == "resumed_line" - finally: - for path in file_paths: - if os.path.exists(path): - os.remove(path) - os.rmdir(temp_dir) + # Verify sleep was called twice during state restoration + assert mock_sleep.call_count == 2 + mock_sleep.assert_any_call(1.0) # First retry delay + mock_sleep.assert_any_call(1.0) # Second retry delay def test_text_streaming_decoder_custom_max_retries(): @@ -567,7 +591,12 @@ def test_retry_logic_break_on_success(mock_sleep, mock_smart_open): # Mock smart_open to succeed on first attempt mock_file = MagicMock() mock_file.readline.side_effect = ["success_line\n", ""] - mock_smart_open.return_value = mock_file + + # Set up successful context manager + mock_context = MagicMock() + mock_context.__enter__.return_value = mock_file + mock_context.__exit__.return_value = None + 
mock_smart_open.return_value = mock_context file_paths = ["s3://bucket/test.txt"] source_node = MockSourceNode(file_paths) diff --git a/torchdata/nodes/io/text_streaming_decoder.py b/torchdata/nodes/io/text_streaming_decoder.py index 5953e5422..4266d55df 100644 --- a/torchdata/nodes/io/text_streaming_decoder.py +++ b/torchdata/nodes/io/text_streaming_decoder.py @@ -155,13 +155,8 @@ def reset(self, initial_state: Optional[Dict[str, Any]] = None): encoding=self.encoding, transport_params=self.transport_params, ) - # Prefer direct streaming handle when available - if hasattr(cm, "readline"): - self._file_handle = cm - elif hasattr(cm, "__enter__"): - self._file_handle = cm.__enter__() - else: - self._file_handle = cm + # smart_open returns a context manager - enter it to get file handle + self._file_handle = cm.__enter__() # Skip lines to resume position using streaming readline for _ in range(self._current_line): _ = self._file_handle.readline() @@ -226,13 +221,8 @@ def _get_next_file(self) -> bool: cm = smart_open.open( self._current_file, self.mode, encoding=self.encoding, transport_params=self.transport_params ) - # Prefer direct streaming handle when available - if hasattr(cm, "readline"): - self._file_handle = cm - elif hasattr(cm, "__enter__"): - self._file_handle = cm.__enter__() - else: - self._file_handle = cm + # smart_open returns a context manager - enter it to get file handle + self._file_handle = cm.__enter__() self._current_line = 0 return True From 50144911e819cd976cdb0ac14ff15dad24cf05e0 Mon Sep 17 00:00:00 2001 From: keunwoochoi Date: Wed, 10 Sep 2025 16:29:27 -0400 Subject: [PATCH 8/9] explicit false --- torchdata/nodes/io/text_streaming_decoder.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/torchdata/nodes/io/text_streaming_decoder.py b/torchdata/nodes/io/text_streaming_decoder.py index 4266d55df..507c9eabf 100644 --- a/torchdata/nodes/io/text_streaming_decoder.py +++ b/torchdata/nodes/io/text_streaming_decoder.py @@ -242,6 +242,9 @@ def _get_next_file(self) -> bool: self._file_handle = None return False # Failed to open file + # If we get here, all retry attempts failed + return False + except StopIteration: # No more files - this should propagate up to stop iteration raise From 8832f44d84c0ec91ea2da34df4e1cb62908b08ae Mon Sep 17 00:00:00 2001 From: keunwoochoi Date: Wed, 10 Sep 2025 16:47:50 -0400 Subject: [PATCH 9/9] close file explicitly for windows --- torchdata/nodes/io/text_streaming_decoder.py | 30 ++++++++++++-------- 1 file changed, 18 insertions(+), 12 deletions(-) diff --git a/torchdata/nodes/io/text_streaming_decoder.py b/torchdata/nodes/io/text_streaming_decoder.py index 507c9eabf..89eb78860 100644 --- a/torchdata/nodes/io/text_streaming_decoder.py +++ b/torchdata/nodes/io/text_streaming_decoder.py @@ -116,8 +116,19 @@ def __init__( self._current_file: Optional[str] = None self._current_line = 0 self._file_handle: Optional[Any] = None + self._context_manager: Optional[Any] = None # Store context manager for proper cleanup self._source_metadata: Dict[str, Any] = {} + def _close_current_file(self): + """Close the current file and context manager properly.""" + if self._context_manager is not None: + try: + self._context_manager.__exit__(None, None, None) + except Exception: + pass # Ignore errors during cleanup + self._context_manager = None + self._file_handle = None + def reset(self, initial_state: Optional[Dict[str, Any]] = None): """Reset must fully initialize the node's state. 
@@ -127,9 +138,7 @@ def reset(self, initial_state: Optional[Dict[str, Any]] = None): super().reset(initial_state) # Close any open file - if self._file_handle is not None: - self._file_handle.close() - self._file_handle = None + self._close_current_file() if initial_state is None: # Full reset @@ -156,6 +165,7 @@ def reset(self, initial_state: Optional[Dict[str, Any]] = None): transport_params=self.transport_params, ) # smart_open returns a context manager - enter it to get file handle + self._context_manager = cm self._file_handle = cm.__enter__() # Skip lines to resume position using streaming readline for _ in range(self._current_line): @@ -222,6 +232,7 @@ def _get_next_file(self) -> bool: self._current_file, self.mode, encoding=self.encoding, transport_params=self.transport_params ) # smart_open returns a context manager - enter it to get file handle + self._context_manager = cm self._file_handle = cm.__enter__() self._current_line = 0 return True @@ -239,7 +250,7 @@ def _get_next_file(self) -> bool: logger.error( f"Failed to open {self._current_file} after {self.max_retries + 1} attempts. Last error: {e}" ) - self._file_handle = None + self._close_current_file() return False # Failed to open file # If we get here, all retry attempts failed @@ -270,8 +281,7 @@ def _get_next_line(self) -> Optional[Dict[str, Any]]: # EOF or empty line at end of file if not line: - self._file_handle.close() - self._file_handle = None + self._close_current_file() return None # Signal end of file # Create output with metadata @@ -287,9 +297,7 @@ def _get_next_line(self) -> Optional[Dict[str, Any]]: except Exception as e: logger.error(f"Error reading from {self._current_file}: {e}") - if self._file_handle: - self._file_handle.close() - self._file_handle = None + self._close_current_file() return None # Signal error def next(self) -> Dict[str, Any]: @@ -321,6 +329,4 @@ def get_state(self) -> Dict[str, Any]: def shutdown(self): """Shutdown the node.""" - if self._file_handle is not None: - self._file_handle.close() - self._file_handle = None + self._close_current_file()
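
Net effect of the series: TextStreamingDecoder streams lines from any
smart_open-supported URI, retries opens with Fibonacci backoff, and
can checkpoint and resume mid-file. A minimal checkpoint/resume
sketch, following test_text_stream_state_management (source_node and
fresh_source stand in for any node yielding file paths, like the
tests' MockSourceNode):

    from torchdata.nodes.io import TextStreamingDecoder

    node = TextStreamingDecoder(source_node, max_retries=3)
    first = next(iter(node))      # line 0 of the first file
    state = node.get_state()      # source state + current_file + current_line

    resumed = TextStreamingDecoder(fresh_source, max_retries=3)
    resumed.reset(state)          # reopens current_file, skips current_line lines
    second = next(iter(resumed))  # continues with line index 1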