Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion .github/workflows/ruff.yml
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,6 @@ jobs:
continue-on-error: true
steps:
- uses: actions/checkout@v3
- uses: astral-sh/ruff-action@v1
- uses: astral-sh/ruff-action@v3
with:
args: check --verbose
7 changes: 7 additions & 0 deletions examples/visualize_mapgen/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
# README

This a file which aides visualization of the geometric objects in mapgen.

So far there is:

* [`make_cart_rect`](https://github.com/djps/k-wave-python/blob/mapgen_plotting/kwave/utils/mapgen.py#L2591)
820 changes: 820 additions & 0 deletions examples/visualize_mapgen/visualize_mapgen.ipynb

Large diffs are not rendered by default.

152 changes: 102 additions & 50 deletions kwave/utils/mapgen.py
Original file line number Diff line number Diff line change
@@ -1,36 +1,34 @@
import logging
import math
import warnings
from math import floor
from math import ceil, floor

import matplotlib.pyplot as plt
import numpy as np
import scipy
from beartype import beartype as typechecker
from beartype.typing import List, Optional, Tuple, Union, cast
from beartype.typing import List, Literal, Optional, Tuple, Union, cast
from jaxtyping import Complex, Float, Int, Integer, Real
from scipy import optimize
from scipy.optimize import fmin
from scipy.spatial.transform import Rotation
from scipy.special import jv

import kwave.utils.typing as kt
from kwave.utils.math import compute_linear_transform, compute_rotation_between_vectors

from ..data import Vector
from .conversion import db2neper, neper2db
from .data import scale_SI
from .math import Rx, Ry, Rz, compute_linear_transform, cosd, sind
from .matlab import ind2sub, matlab_assign, matlab_find, sub2ind
from .matrix import max_nd
from .tictoc import TicToc
from kwave.data import Vector
from kwave.utils.conversion import db2neper, neper2db
from kwave.utils.data import scale_SI
from kwave.utils.math import compute_linear_transform, cosd, sind
from kwave.utils.matlab import ind2sub, matlab_assign, matlab_find, sub2ind
from kwave.utils.matrix import max_nd
from kwave.utils.tictoc import TicToc

# GLOBALS
# define literals (ref: http://www.wolframalpha.com/input/?i=golden+angle)
GOLDEN_ANGLE = 2.39996322972865332223155550663361385312499901105811504
PACKING_NUMBER = 7 # 2*pi


@typechecker
def make_cart_disc(
disc_pos: np.ndarray, radius: float, focus_pos: np.ndarray, num_points: int, plot_disc: bool = False, use_spiral: bool = False
disc_pos: np.ndarray, radius: float, focus_pos: np.ndarray, num_points: int,
plot_disc: Optional[Union[bool, Literal[0, 1]]] = False, use_spiral: Optional[Union[bool, Literal[0, 1]]] = False
) -> np.ndarray:
"""
Create evenly distributed Cartesian points covering a disc.
Expand Down Expand Up @@ -113,7 +111,7 @@ def make_concentric_circle_points(num_points: int, radius: float) -> Tuple[np.nd
# specified disc
if len(disc_pos) == 3:
# check the focus position isn't coincident with the disc position
if all(disc_pos == focus_pos):
if np.all(np.isclose(np.array(disc_pos), np.array(focus_pos))):
raise ValueError("The focus_pos must be different from the disc_pos.")

# compute rotation matrix and apply
Expand All @@ -129,17 +127,23 @@ def make_concentric_circle_points(num_points: int, radius: float) -> Tuple[np.nd
_, scale, prefix, unit = scale_SI(np.max(disc))

# create the figure
fig = plt.figure()
cmap = plt.get_cmap('viridis', np.shape(disc)[1])

if len(disc_pos) == 2:
plt.figure()
plt.plot(disc[1, :] * scale, disc[0, :] * scale, ".")
plt.gca().invert_yaxis()
plt.xlabel(f"y-position [{prefix}m]")
plt.ylabel(f"x-position [{prefix}m]")
plt.axis("equal")
ax = fig.add_subplot(111)
ax.scatter(disc[1, :] * scale, disc[0, :] * scale, marker='.',
c=np.arange(np.shape(disc)[1]), cmap=cmap, alpha=0.9, edgecolor=None)
ax.invert_yaxis()
ax.xlabel(f"y-position [{prefix}m]")
ax.ylabel(f"x-position [{prefix}m]")
ax.axis("equal")
ax.grid(True)
ax.box(True)
else:
fig = plt.figure()
ax = fig.add_subplot(111, projection="3d")
ax.plot3D(disc[0, :] * scale, disc[1, :] * scale, disc[2, :] * scale, ".")
ax.scatter(disc[0, :] * scale, disc[1, :] * scale, disc[2, :] * scale, marker='.',
c=np.arange(np.shape(disc)[1]), cmap=cmap, alpha=0.9, edgecolor=None)
ax.set_xlabel(f"[{prefix}m]")
ax.set_ylabel(f"[{prefix}m]")
ax.set_zlabel(f"[{prefix}m]")
Expand Down Expand Up @@ -188,9 +192,8 @@ def make_cart_bowl(

# check for infinite radius of curvature, and call makeCartDisc instead
if np.isinf(radius):
# bowl = make_cart_disc(bowl_pos, diameter / 2, focus_pos, num_points, plot_bowl)
# return bowl
raise NotImplementedError("make_cart_disc")
bowl = make_cart_disc(bowl_pos, diameter / 2, focus_pos, num_points, plot_bowl)
return bowl

# compute arc angle from chord (ref: https://en.wikipedia.org/wiki/Chord_(geometry))
varphi_max = np.arcsin(diameter / (2 * radius))
Expand All @@ -217,11 +220,12 @@ def make_cart_bowl(
if plot_bowl is True:
# select suitable axis scaling factor
_, scale, prefix, unit = scale_SI(np.max(bowl))

# create the figure
fig = plt.figure()
cmap = plt.get_cmap('viridis', np.shape(bowl)[1])
ax = fig.add_subplot(111, projection="3d")
ax.scatter(bowl[0, :] * scale, bowl[1, :] * scale, bowl[2, :] * scale)
ax.scatter(bowl[0, :] * scale, bowl[1, :] * scale, bowl[2, :] * scale, marker='.',
c=np.arange(np.shape(bowl)[1]), cmap=cmap, alpha=0.9, edgecolor=None)
ax.set_xlabel("[" + prefix + unit + "]")
ax.set_ylabel("[" + prefix + unit + "]")
ax.set_zlabel("[" + prefix + unit + "]")
Expand Down Expand Up @@ -315,7 +319,7 @@ def abs_func(trial_vals):

return absorption_error

a0_np_fit, y_fit = optimize.fmin(abs_func, [a0_np, y])
a0_np_fit, y_fit = fmin(abs_func, [a0_np, y])

a0_fit = neper2db(a0_np_fit, y_fit)

Expand Down Expand Up @@ -549,17 +553,28 @@ def make_ball(

# plot results
if plot_ball:
raise NotImplementedError
# voxelPlot(double(ball))
_, scale, prefix, _ = scale_SI(np.max(ball))
fig = plt.figure()
ax = fig.add_subplot(111, projection="3d")
ax.scatter(ball[0] * scale, ball[1] * scale, ball[2] * scale)
ax.set_xlabel("[" + prefix + "m]")
ax.set_ylabel("[" + prefix + "m]")
ax.set_zlabel("[" + prefix + "m]")
ax.set_box_aspect([1, 1, 1])
plt.grid(True)
plt.show()
return ball


@typechecker
def make_cart_sphere(
radius: Union[float, int], num_points: int, center_pos: Vector = Vector([0, 0, 0]), plot_sphere: bool = False
radius: Union[float, int],
num_points: int,
center_pos: Optional[Union[Real[kt.ScalarLike, "3"], List, Int[np.ndarray, "3"], Float[np.ndarray, "3"]]] = np.zeros((3,)),
plot_sphere: bool = False
) -> Float[np.ndarray, "3 NumPoints"]:
"""
Cart_sphere creates a set of points in Cartesian coordinates defining a sphere.
Creates a set of points in Cartesian coordinates defining a sphere.

Args:
radius: the radius of the sphere.
Expand Down Expand Up @@ -592,12 +607,15 @@ def make_cart_sphere(
if plot_sphere:
# select suitable axis scaling factor
[x_sc, scale, prefix, _] = scale_SI(np.max(sphere))


cmap = plt.get_cmap('viridis', np.shape(sphere)[1])

# create the figure
plt.figure()
plt.style.use("seaborn-poster")
# plt.style.use("seaborn-poster")
ax = plt.axes(projection="3d")
ax.plot3D(sphere[0, :] * scale, sphere[1, :] * scale, sphere[2, :] * scale, ".")
ax.scatter(sphere[0, :] * scale, sphere[1, :] * scale, sphere[2, :] * scale, marker='.',
c=np.arange(np.shape(sphere)[1]), cmap=cmap, alpha=0.9, edgecolor=None)
ax.set_xlabel(f"[{prefix} m]")
ax.set_ylabel(f"[{prefix} m]")
ax.set_zlabel(f"[{prefix} m]")
Expand All @@ -607,7 +625,7 @@ def make_cart_sphere(

return sphere.squeeze()


@typechecker
def make_cart_circle(
radius: float, num_points: int, center_pos: Vector = Vector([0, 0]), arc_angle: float = 2 * np.pi, plot_circle: bool = False
) -> Float[np.ndarray, "2 NumPoints"]:
Expand Down Expand Up @@ -715,7 +733,12 @@ def make_disc(grid_size: Vector, center: Vector, radius, plot_disc=False) -> kt.

# create the figure
if plot_disc:
raise NotImplementedError
_, ax = plt.subplots(1, 1)
ax.imshow(disc)
ax.set_aspect('auto', adjustable='box')
ax.yaxis.set_inverted(True)
plt.show()

return disc


Expand Down Expand Up @@ -2333,8 +2356,8 @@ def make_sphere(
and optional flags to plot the sphere and/or return a binary mask.

Args:
grid_size: The size of the grid (assumed to be cubic).
radius: The radius of the sphere.
grid_size: The size of the grid (assumed to be three dimensional).
radius: The radius of the sphere in grid points
plot_sphere: Whether to plot the sphere (default: False).
binary: Whether to return a binary mask (default: False).

Expand Down Expand Up @@ -2415,7 +2438,15 @@ def make_sphere(

# plot results
if plot_sphere:
raise NotImplementedError
# create the figure: this is a binary mask.
fig = plt.figure()
cmap = plt.get_cmap('viridis', np.shape(sphere)[0])
ax = fig.add_subplot(111, projection="3d")
ax.scatter(sphere[0], sphere[1], sphere[2], marker='s', c=np.arange(np.shape(sphere)[0]), cmap=cmap, alpha=0.9, edgecolor=None)
ax.set_box_aspect([1, 1, 1])
plt.grid(True)
plt.show()

return sphere


Expand Down Expand Up @@ -2444,7 +2475,6 @@ def make_spherical_section(

Raises:
ValueError: If the width is not an odd number.
NotImplementedError: Plotting not currently supported.
"""
use_spherical_sections = True

Expand Down Expand Up @@ -2557,7 +2587,13 @@ def make_spherical_section(

# plot if required
if plot_section:
raise NotImplementedError
# create the figure
fig = plt.figure()
ax = fig.add_subplot(111, projection="3d")
ax.scatter(ss[0], ss[1], ss[2], marker='s', c='black', alpha=0.9)
ax.set_box_aspect([1, 1, 1])
plt.grid(True)
plt.show()

return ss, dist_map

Expand Down Expand Up @@ -2587,8 +2623,8 @@ def make_cart_rect(
"""

# Find number of points in along each axis
npts_x = math.ceil(np.sqrt(num_points * Lx / Ly))
npts_y = math.ceil(num_points / npts_x)
npts_x = ceil(np.sqrt(num_points * Lx / Ly))
npts_y = ceil(num_points / npts_x)

# Recalculate the true number of points
num_points = npts_x * npts_y
Expand Down Expand Up @@ -2639,6 +2675,20 @@ def make_cart_rect(
# Shift the rectangle to the appropriate centre
rect = p0 + np.expand_dims(np.array(rect_pos), axis=1)

if plot_rect:
# create the figure
fig = plt.figure()
cmap = plt.get_cmap('viridis', np.shape(rect)[1])
if len(rect_pos) == 3:
ax = fig.add_subplot(111, projection="3d")
ax.scatter(rect[0], rect[1], rect[2], marker='s', c=np.arange(np.shape(rect)[1]), cmap=cmap, alpha=0.9, edgecolor=None)
if len(rect_pos) == 2:
ax = fig.add_subplot(111)
ax.scatter(rect[1, :], rect[0, :], marker='s', c=np.arange(np.shape(rect)[1]), cmap=cmap, alpha=0.9, edgecolor=None)
ax.invert_yaxis()
plt.grid(True)
plt.show()

return rect


Expand Down Expand Up @@ -2724,7 +2774,7 @@ def calculate_lateral_pressure() -> Float[np.ndarray, "N"]:
Z = k * lateral_positions * diameter / (2 * radius)
# TODO: this should work
# assert np.all(Z) > 0, 'Z must be greater than 0'
lateral_pressure = 2.0 * density * sound_speed * velocity * k * h * scipy.special.jv(1, Z) / Z
lateral_pressure = 2.0 * density * sound_speed * velocity * k * h * jv(1, Z) / Z

# replace origin with limit
lateral_pressure[lateral_positions == 0] = density * sound_speed * velocity * k * h
Expand Down Expand Up @@ -2905,7 +2955,7 @@ def make_cart_arc(
plot_arc: bool = False,
) -> Float[np.ndarray, "2 NumPoints"]:
"""
make_cart_arc creates a 2 x num_points array of the Cartesian
Creates a 2 x num_points array of the Cartesian
coordinates of points evenly distributed over an arc. The midpoint of
the arc is set by arc_pos. The orientation of the arc is set by
focus_pos, which corresponds to any point on the axis of the arc
Expand Down Expand Up @@ -2966,7 +3016,9 @@ def make_cart_arc(

# Create the figure
plt.figure()
plt.plot(arc[1, :] * scale, arc[0, :] * scale, "b.")
cmap = plt.get_cmap('viridis', np.shape(arc)[1])
plt.scatter(arc[1, :] * scale, arc[0, :] * scale, marker='s', c=np.arange(np.shape(arc)[1]),
cmap=cmap, alpha=0.9, edgecolor=None)
plt.gca().invert_yaxis()
plt.xlabel(f"y-position [{prefix}m]")
plt.ylabel(f"x-position [{prefix}m]")
Expand Down
7 changes: 4 additions & 3 deletions kwave/utils/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -383,7 +383,7 @@ def gaussian(

def _compute_direction(start_pos: np.ndarray, end_pos: np.ndarray) -> Tuple[np.ndarray, float]:
"""Compute normalized direction vector and magnitude between two points."""
direction = end_pos - start_pos
direction = np.asarray(end_pos) - np.asarray(start_pos)
magnitude = np.linalg.norm(direction)
direction = direction / magnitude
return direction, magnitude
Expand Down Expand Up @@ -457,11 +457,12 @@ def compute_linear_transform(pos1, pos2, offset=None):
Returns:
Tuple containing:
- 3x3 rotation matrix
- offset position

"""
rot_mat, direction = compute_rotation_between_vectors(pos1, pos2)
rot_mat, direction = compute_rotation_between_vectors(np.asarray(pos1), np.asarray(pos2))
if offset is not None:
offset_pos = pos1 + offset * direction
offset_pos = np.asarray(pos1) + offset * direction
else:
offset_pos = 0
return rot_mat, offset_pos
Loading