8 changes: 5 additions & 3 deletions README.md
@@ -12,16 +12,18 @@ tensorswitch/
├── src
│ └── tensorswitch
│ ├── __init__.py
│ ├── __main__.py # Main dispatcher script
│ ├── tasks
│ │ ├── __init__.py
│ │ ├── downsample_shard_zarr3.py # Downsample using shards
│ │ ├── n5_to_n5.py # N5 to N5 conversion logic
│ │ ├── n5_to_zarr2.py # N5 to Zarr V2 conversion logic
│ │ ├── tiff_to_zarr3_s0.py # TIFF to Zarr V3 level s0 conversion logic
│ ├── utils.py # Common utilities (chunk domain calculation)
│ ├── z_to_chunk_index.py # Print chunk index ranges for resubmitting failed or leftover jobs
├── re_submit_jobs.ipynb # Jupyter notebook to re-submit failed chunk jobs
├── contrib
│ ├── re_submit_jobs.ipynb # Jupyter notebook to re-submit failed chunk jobs
│ ├── start_neuroglancer_server.py # Start a CORS-enabled web server
│ └── z_to_chunk_index.py # Print chunk index ranges for resubmitting failed or leftover jobs
└── tests
├── test_n5_to_n5.py
├── test_n5_to_zarr2.py
29 changes: 29 additions & 0 deletions contrib/start_neuroglancer_server.py
@@ -0,0 +1,29 @@
import http.server
import socketserver
import os

PORT = 8866

class CORSHandler(http.server.SimpleHTTPRequestHandler):
def end_headers(self):
self.send_header('Access-Control-Allow-Origin', '*')
self.send_header('Access-Control-Allow-Methods', 'GET, POST, OPTIONS')
self.send_header('Access-Control-Allow-Headers', 'X-Requested-With, Content-Type')
super().end_headers()

def do_OPTIONS(self):
self.send_response(200)
self.end_headers()

if __name__ == "__main__":
# Serve from the directory that contains the data to expose.
# Adjust this path if your data lives elsewhere, or remove the
# chdir call to serve from the project root.
os.chdir('test_for_AhrensLab/output_zarr3/')

with socketserver.TCPServer(("", PORT), CORSHandler) as httpd:
print(f"Serving Neuroglancer data at http://localhost:{PORT}")
print("Press Ctrl+C to stop the server.")
httpd.serve_forever()
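
With the server running (python contrib/start_neuroglancer_server.py), the converted volumes can be fetched over HTTP. Below is a minimal sketch, not part of this diff, of reading the served data with tensorstore's http kvstore driver; the port matches PORT above, while the dataset path 'multiscale/s0/' is only an illustrative assumption:

import tensorstore as ts

# Open the served Zarr v3 dataset over HTTP ('multiscale/s0/' is illustrative).
store = ts.open({
    'driver': 'zarr3',
    'kvstore': {
        'driver': 'http',
        'base_url': 'http://localhost:8866/',
        'path': 'multiscale/s0/',
    },
}).result()
print(store.shape, store.dtype)

The same base URL can be pasted into Neuroglancer as a data source; the Access-Control-Allow-Origin header above is what lets the browser-based viewer fetch chunks cross-origin.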
File renamed without changes.
4 changes: 2 additions & 2 deletions src/tensorswitch/__main__.py
@@ -14,7 +14,7 @@
sys.path.insert(0, package_source_path)

from . import tasks
from .utils import get_total_chunks, downsample_spec, zarr3_store_spec, get_chunk_domains, estimate_total_chunks_for_tiff, get_input_driver
from .utils import get_total_chunks, downsample_spec, zarr3_store_spec, get_chunk_domains, estimate_total_chunks_for_tiff, get_input_driver, get_total_chunks_from_store
from .tasks import downsample_shard_zarr3
from .tasks import n5_to_n5
from .tasks import n5_to_zarr2
@@ -73,7 +73,7 @@ def submit_job(args):
# Count output chunks
# total_chunks = get_total_chunks(downsampled_saved_spec)
chunk_shape = downsample_store.chunk_layout.read_chunk.shape
total_chunks = len(get_chunk_domains(chunk_shape, downsample_store))
total_chunks = get_total_chunks_from_store(downsample_store, chunk_shape=chunk_shape)

else:
total_chunks = get_total_chunks(args.base_path)
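
The change above is the heart of this PR: instead of materializing every chunk domain just to count them, the count is derived arithmetically from the store's shape. get_total_chunks_from_store is defined in utils.py, outside this diff; a plausible sketch under that assumption:

import math

def get_total_chunks_from_store(store, chunk_shape=None):
    """Count chunks from shape arithmetic alone, without building domain objects."""
    if chunk_shape is None:
        chunk_shape = store.chunk_layout.read_chunk.shape
    # Round up so partial chunks at the array edges are counted.
    return math.prod(math.ceil(d / c) for d, c in zip(store.shape, chunk_shape))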
25 changes: 7 additions & 18 deletions src/tensorswitch/tasks/downsample_shard_zarr3.py
@@ -2,18 +2,12 @@
import numpy as np
import time
import psutil
from ..utils import get_chunk_domains, create_output_store, commit_tasks, print_processing_info, downsample_spec, zarr3_store_spec, get_input_driver
from ..utils import get_chunk_domains, create_output_store, commit_tasks, print_processing_info, downsample_spec, zarr3_store_spec, get_input_driver, get_total_chunks_from_store
import os

def process(base_path, output_path, level, start_idx=0, stop_idx=None, downsample=True, use_shard=True, memory_limit=50, **kwargs):

"""Downsample and optionally apply sharding to Zarr3 dataset."""
'''
if level == 0:
zarr_input_path = base_path
else:
zarr_input_path = f"{base_path}/multiscale/s{level-1}"
'''
if base_path.endswith(f"s{level - 1}") or level == 0:
zarr_input_path = base_path
else:
@@ -34,12 +28,6 @@ def process(base_path, output_path, level, start_idx=0, stop_idx=None, downsampl
print(f"Reading from: {zarr_input_path}")
print(f"Writing to: {downsampled_saved_path}")

'''
zarr_store_spec = {
'driver': 'zarr' + ('3' if level > 0 else ''),
'kvstore': {'driver': 'file', 'path': zarr_input_path}
}
'''
zarr_store = ts.open(zarr_store_spec).result()

if downsample and level > 0:
@@ -63,17 +51,18 @@ def process(base_path, output_path, level, start_idx=0, stop_idx=None, downsampl
chunk_shape = downsample_store.chunk_layout.read_chunk.shape
print("Shape of downsample_store:", downsample_store.shape)
print("Chunk shape used:", chunk_shape)
chunk_domains = get_chunk_domains(chunk_shape, downsample_store) # compute chunk domains based on the downsampled input when going from s0 to s1

# Chunk domains are computed from the downsampled input when going from s0 to s1
total_chunks = get_total_chunks_from_store(downsample_store, chunk_shape=chunk_shape)

if stop_idx is None:
stop_idx = len(chunk_domains)
stop_idx = total_chunks

print_processing_info(level, start_idx, stop_idx, len(chunk_domains))
print_processing_info(level, start_idx, stop_idx, total_chunks)

tasks = []
txn = ts.Transaction()
for chunk_domain in chunk_domains[start_idx:stop_idx]:
linear_indices_to_process = range(start_idx, stop_idx)
for chunk_domain in get_chunk_domains(chunk_shape, downsample_store, linear_indices_to_process=linear_indices_to_process):
task = downsampled_saved[chunk_domain].with_transaction(txn).write(downsample_store[chunk_domain])
tasks.append(task)
txn = commit_tasks(tasks, txn, memory_limit)
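
Note the new calling convention: get_chunk_domains now receives a linear_indices_to_process range and is iterated directly, suggesting it yields domains lazily rather than returning a list. A sketch of such a generator, assuming a C-order chunk grid (the actual body in utils.py is not shown in this diff):

import math
import numpy as np
import tensorstore as ts

def get_chunk_domains(chunk_shape, store, linear_indices_to_process=None):
    """Lazily yield a ts.IndexDomain for each requested linear chunk index."""
    grid = [math.ceil(d / c) for d, c in zip(store.shape, chunk_shape)]
    if linear_indices_to_process is None:
        linear_indices_to_process = range(math.prod(grid))
    for linear_idx in linear_indices_to_process:
        # Map the linear index to an N-D position on the chunk grid.
        coords = np.unravel_index(linear_idx, grid)
        lo = [int(g) * c for g, c in zip(coords, chunk_shape)]
        # Clip the last chunk along each axis to the array bounds.
        hi = [min(lo_i + c, d) for lo_i, c, d in zip(lo, chunk_shape, store.shape)]
        yield ts.IndexDomain(inclusive_min=lo, exclusive_max=hi)

Streaming the domains keeps memory flat even when a level has millions of chunks, which is presumably why the list-based API was retired.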
30 changes: 10 additions & 20 deletions src/tensorswitch/tasks/n5_to_n5.py
@@ -2,7 +2,7 @@
import os
import time
import psutil
from ..utils import get_chunk_domains, n5_store_spec, create_output_store, commit_tasks, print_processing_info, fetch_http_json
from ..utils import get_chunk_domains, n5_store_spec, create_output_store, commit_tasks, print_processing_info, fetch_http_json, get_total_chunks_from_store

def convert(base_path, output_path, number, level, start_idx=0, stop_idx=None, memory_limit=50, **kwargs):
"""Convert N5 to N5 format."""
@@ -41,35 +41,25 @@ def convert(base_path, output_path, number, level, start_idx=0, stop_idx=None, m
}

n5_output_store = create_output_store(n5_output_spec)
chunk_domains = get_chunk_domains(chunk_shape, n5_output_store)
print(f"Generated {len(chunk_domains)} chunk domains")
total_chunks = get_total_chunks_from_store(n5_output_store, chunk_shape=chunk_shape)
print(f"Generated {total_chunks} chunk domains")
print(f"Start index: {start_idx}, Stop index: {stop_idx}")
print("First few chunk domains:", chunk_domains[:3])

print("Reading from source N5...")
sample = n5_store[chunk_domains[0]].read().result()
print("Sample shape:", sample.shape)

sample_chunk_domain_for_print = next(iter(get_chunk_domains(chunk_shape, n5_output_store)))
sample = n5_store[sample_chunk_domain_for_print].read().result()
print("Sample shape:", sample.shape)

if stop_idx is None:
stop_idx = len(chunk_domains)
stop_idx = total_chunks

print_processing_info(level, start_idx, stop_idx, len(chunk_domains))
print_processing_info(level, start_idx, stop_idx, total_chunks)

tasks = []
txn = ts.Transaction()
"""
for chunk_domain in chunk_domains[start_idx:stop_idx]:
#task = n5_output_store[chunk_domain].with_transaction(txn).write(n5_store[chunk_domain])
array = n5_store[chunk_domain].read().result()
task = n5_output_store[chunk_domain].with_transaction(txn).write(array)
tasks.append(task)
txn = commit_tasks(tasks, txn, memory_limit)
print(f"Writing chunk: {chunk_domain}, array shape: {array.shape}")

"""

for idx, chunk_domain in enumerate(chunk_domains[start_idx:stop_idx], start=start_idx):
linear_indices_to_process = range(start_idx, stop_idx)
for idx, chunk_domain in enumerate(get_chunk_domains(chunk_shape, n5_output_store, linear_indices_to_process=linear_indices_to_process), start=start_idx):
try:
array = n5_store[chunk_domain].read().result()
except Exception as e:
13 changes: 7 additions & 6 deletions src/tensorswitch/tasks/n5_to_zarr2.py
@@ -2,7 +2,7 @@
import os
import time
import psutil
from ..utils import get_chunk_domains, n5_store_spec, zarr2_store_spec, create_output_store, commit_tasks, print_processing_info
from ..utils import get_chunk_domains, n5_store_spec, zarr2_store_spec, create_output_store, commit_tasks, print_processing_info, get_total_chunks_from_store

def convert(base_path, output_path, level, start_idx=0, stop_idx=None, memory_limit=50, **kwargs):
"""Convert N5 to Zarr2 format."""
@@ -18,19 +18,20 @@ def convert(base_path, output_path, level, start_idx=0, stop_idx=None, memory_li
zarr2_spec = zarr2_store_spec(zarr_level_path, shape, chunks)
zarr2_store = create_output_store(zarr2_spec)

chunk_domains = get_chunk_domains(chunks, zarr2_store)
print(f" Total chunks to write: {len(chunk_domains)}")
total_chunks = get_total_chunks_from_store(zarr2_store, chunk_shape=chunks)
print(f" Total chunks to write: {total_chunks}")
print(f" Writing from chunk {start_idx} to {stop_idx}")


if stop_idx is None:
stop_idx = len(chunk_domains)
stop_idx = total_chunks

print_processing_info(level, start_idx, stop_idx, len(chunk_domains))
print_processing_info(level, start_idx, stop_idx, total_chunks)

tasks = []
txn = ts.Transaction()
for chunk_domain in chunk_domains[start_idx:stop_idx]:
linear_indices_to_process = range(start_idx, stop_idx)
for chunk_domain in get_chunk_domains(chunks, zarr2_store, linear_indices_to_process=linear_indices_to_process):
task = zarr2_store[chunk_domain].with_transaction(txn).write(n5_store[chunk_domain])
tasks.append(task)
txn = commit_tasks(tasks, txn, memory_limit)
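
Every conversion loop funnels its pending writes through commit_tasks(tasks, txn, memory_limit). Its implementation lives in utils.py and is not part of this diff; a hedged sketch, assuming memory_limit is a percentage checked via psutil (which these modules import):

import psutil
import tensorstore as ts

def commit_tasks(tasks, txn, memory_limit):
    """Flush the open transaction once memory use crosses memory_limit percent."""
    if psutil.virtual_memory().percent >= memory_limit:
        for task in tasks:
            task.result()        # wait for the batched chunk writes
        txn.commit_sync()        # persist them to the output store
        tasks.clear()
        return ts.Transaction()  # fresh transaction for the next batch
    return txn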
38 changes: 25 additions & 13 deletions src/tensorswitch/tasks/tiff_to_zarr3_s0.py
@@ -1,5 +1,5 @@
from dask.cache import Cache
from ..utils import load_tiff_stack, zarr3_store_spec, get_chunk_domains, commit_tasks
from ..utils import load_tiff_stack, zarr3_store_spec, get_chunk_domains, commit_tasks, get_total_chunks_from_store
import tensorstore as ts
import numpy as np
import psutil
@@ -13,7 +13,20 @@ def process(base_path, output_path, use_shard=False, memory_limit=50, start_idx=
print(f"Loading TIFF stack from: {base_path}", flush=True)

volume = load_tiff_stack(base_path)
print(f"Volume shape: {volume.shape}, dtype: {volume.dtype}", flush=True)
print(f"Original volume shape: {volume.shape}, dtype: {volume.dtype}", flush=True)
print(f"Original chunk structure from dask: {volume.chunksize}", flush=True)

# DEBUG
print(f"Volume dimensions: {len(volume.shape)}D")
print(f"Volume chunk structure from dask: {volume.chunksize}")

# DEBUG: what a single chunk looks like
if len(volume.shape) == 4:
print("4D array detected - likely (C, Z, Y, X)")
print(f"Channels: {volume.shape[0]}")
print(f"Z-slices: {volume.shape[1]}")
print(f"Y (height): {volume.shape[2]}")
print(f"X (width): {volume.shape[3]}")

# Enable Dask opportunistic cache with 8 GB RAM
cache = Cache(8 * 1024**3) # 8 GiB = 8 × 1024³ = 8,589,934,592 bytes
@@ -31,23 +44,22 @@ def process(base_path, output_path, use_shard=False, memory_limit=50, start_idx=

# Prepare chunk domains and filter to assigned range
chunk_shape = store.chunk_layout.write_chunk.shape
chunk_domains = get_chunk_domains(chunk_shape, store)
chunk_domains = chunk_domains[start_idx:stop_idx] if stop_idx is not None else chunk_domains[start_idx:]
total_chunks = get_total_chunks_from_store(store, chunk_shape=chunk_shape)
print(f"Total chunks: {total_chunks}")
linear_indices_to_process = range(start_idx, stop_idx if stop_idx is not None else total_chunks)
chunk_domains = get_chunk_domains(chunk_shape, store, linear_indices_to_process=linear_indices_to_process)

print(f"Processing {len(chunk_domains)} chunks: start={start_idx}, stop={stop_idx}", flush=True)
print(f"Processing {len(linear_indices_to_process)} chunks: start={start_idx}, stop={stop_idx}", flush=True)

tasks = []
ntasks = 0
txn = ts.Transaction()

for domain in chunk_domains:
task = store[domain].with_transaction(txn).write(
volume[
domain.inclusive_min[0]:domain.exclusive_max[0],
domain.inclusive_min[1]:domain.exclusive_max[1],
domain.inclusive_min[2]:domain.exclusive_max[2],
].compute()
)
# Handle both 3D and 4D arrays dynamically; lo/hi avoid shadowing the built-in min/max
slices = tuple(slice(lo, hi) for lo, hi in zip(domain.inclusive_min, domain.exclusive_max))
slice_data = volume[slices]
task = store[domain].with_transaction(txn).write(slice_data.compute())

tasks.append(task)
ntasks += 1
Expand All @@ -63,4 +75,4 @@ def process(base_path, output_path, use_shard=False, memory_limit=50, start_idx=
task.result()

txn.commit_sync()
print(f"Completed writing Zarr3 s0 at: {output_path} [{start_idx}:{stop_idx}]", flush=True)
print(f"Completed writing Zarr3 s0 at: {output_path} [{start_idx}:{stop_idx}]", flush=True)