diff --git a/src/main/java/net/imglib2/algorithm/morphology/distance/DistanceTransform.java b/src/main/java/net/imglib2/algorithm/morphology/distance/DistanceTransform.java index 24f4e2482..48d9182c4 100644 --- a/src/main/java/net/imglib2/algorithm/morphology/distance/DistanceTransform.java +++ b/src/main/java/net/imglib2/algorithm/morphology/distance/DistanceTransform.java @@ -58,6 +58,7 @@ import net.imglib2.img.cell.CellImgFactory; import net.imglib2.type.BooleanType; import net.imglib2.type.NativeType; +import net.imglib2.type.numeric.IntegerType; import net.imglib2.type.numeric.RealType; import net.imglib2.type.numeric.integer.LongType; import net.imglib2.type.numeric.real.DoubleType; @@ -81,6 +82,7 @@ *

* * @author Philipp Hanslovsky + * @author John Bogovic */ public class DistanceTransform { @@ -391,6 +393,8 @@ public static < T extends RealType< T > > void transform( transform( source, source, d ); } + + /** * Create * distance @@ -528,7 +532,7 @@ public static < T extends RealType< T >, U extends RealType< U >, V extends Real { transformAlongDimension( ( RandomAccessible< T > ) Views.addDimension( source ), - Views.interval( Views.addDimension( target ), new FinalInterval( target.dimension( 0 ), 1 ) ), + Views.addDimension( target, 0, 0 ), d, 0 ); } @@ -601,7 +605,7 @@ public static < T extends RealType< T >, U extends RealType< U >, V extends Real { transformAlongDimensionParallel( ( RandomAccessible< T > ) Views.addDimension( source ), - Views.interval( Views.addDimension( target ), new FinalInterval( target.dimension( 0 ), 1 ) ), + Views.addDimension( target, 0, 0 ), d, 0, es, @@ -734,8 +738,7 @@ public static < B extends BooleanType< B >, U extends RealType< U >, V extends R final DISTANCE_TYPE distanceType, final double... weights ) { - - final U maxVal = Util.getTypeFromInterval( tmp ).createVariable(); + final U maxVal = tmp.getType().createVariable(); maxVal.setReal( maxVal.getMaxValue() ); final Converter< B, U > converter = new BinaryMaskToCost<>( maxVal ); final RandomAccessible< U > converted = Converters.convert( source, converter, maxVal.createVariable() ); @@ -789,7 +792,7 @@ public static < B extends BooleanType< B >, U extends RealType< U >, V extends R final int nTasks, final double... weights ) throws InterruptedException, ExecutionException { - final U maxVal = Util.getTypeFromInterval( tmp ).createVariable(); + final U maxVal = tmp.getType().createVariable(); maxVal.setReal( maxVal.getMaxValue() ); final Converter< B, U > converter = new BinaryMaskToCost<>( maxVal ); final RandomAccessible< U > converted = Converters.convert( source, converter, maxVal.createVariable() ); @@ -916,7 +919,7 @@ public static < B extends BooleanType< B >, U extends RealType< U >, V extends R final RandomAccessibleInterval< V > target, final Distance d ) { - final U maxVal = Util.getTypeFromInterval( tmp ).createVariable(); + final U maxVal = tmp.getType().createVariable(); maxVal.setReal( maxVal.getMaxValue() ); final Converter< B, U > converter = new BinaryMaskToCost<>( maxVal ); final RandomAccessible< U > converted = Converters.convert( source, converter, maxVal.createVariable() ); @@ -963,7 +966,7 @@ public static < B extends BooleanType< B >, U extends RealType< U >, V extends R final ExecutorService es, final int nTasks ) throws InterruptedException, ExecutionException { - final U maxVal = Util.getTypeFromInterval( tmp ).createVariable(); + final U maxVal = tmp.getType().createVariable(); maxVal.setReal( maxVal.getMaxValue() ); final Converter< B, U > converter = new BinaryMaskToCost<>( maxVal ); final RandomAccessible< U > converted = Converters.convert( source, converter, maxVal.createVariable() ); @@ -1005,7 +1008,7 @@ private static < T extends RealType< T >, U extends RealType< U >, V extends Rea { transformL1AlongDimension( ( RandomAccessible< T > ) Views.addDimension( source ), - Views.interval( Views.addDimension( target ), new FinalInterval( target.dimension( 0 ), 1 ) ), + Views.addDimension( target, 0, 0 ), 0, weights[ 0 ] ); } @@ -1073,7 +1076,7 @@ private static < T extends RealType< T >, U extends RealType< U >, V extends Rea { transformL1AlongDimensionParallel( ( RandomAccessible< T > ) Views.addDimension( source ), - Views.interval( Views.addDimension( target ), new FinalInterval( target.dimension( 0 ), 1 ) ), + Views.addDimension( target, 0, 0 ), 0, weights[ 0 ], es, @@ -1106,9 +1109,11 @@ private static < T extends RealType< T >, U extends RealType< U > > void transfo final int lastDim = target.numDimensions() - 1; final long size = target.dimension( dim ); final RealComposite< DoubleType > tmp = Views.collapseReal( createAppropriateOneDimensionalImage( size, new DoubleType() ) ).randomAccess().get(); + // do not permute if we already work on last dimension final Cursor< RealComposite< T > > s = Views.flatIterable( Views.collapseReal( dim == lastDim ? Views.interval( source, target ) : Views.permute( Views.interval( source, target ), dim, lastDim ) ) ).cursor(); final Cursor< RealComposite< U > > t = Views.flatIterable( Views.collapseReal( dim == lastDim ? target : Views.permute( target, dim, lastDim ) ) ).cursor(); + final RealComposite< LongType > lowerBoundDistanceIndex = Views.collapseReal( createAppropriateOneDimensionalImage( size, new LongType() ) ).randomAccess().get(); final RealComposite< DoubleType > envelopeIntersectLocation = Views.collapseReal( createAppropriateOneDimensionalImage( size + 1, new DoubleType() ) ).randomAccess().get(); @@ -1322,6 +1327,541 @@ private static < T extends NativeType< T > & RealType< T > > Img< T > createAppr return size > Integer.MAX_VALUE ? new CellImgFactory<>( t, Integer.MAX_VALUE ).create( dim ) : new ArrayImgFactory<>( t ).create( dim ); } + /** + * Compute the distance and nearest neighbors of a label image. + *

+ * Returns the joint distance transform of all non-background labels in the + * {@code labels} image. This distance transform will be the distance to the + * nearest non-background label at every point. + *

+ * Simultaneously, modifies the {@code labels} image so that the label at + * every point is the closest label, where initial distances are given by + * the values in the distance argument. + *

+ * This method uses 0 (zero) as the background label. + * + * @param + * the label type + * @param labels + * the label image + * @param weights + * for distance computation per dimension + * @return the distance map + */ + public static < L extends IntegerType< L > > RandomAccessibleInterval< DoubleType > voronoiDistanceTransform( + final RandomAccessibleInterval< L > labels, + final double... weights ) + { + return voronoiDistanceTransform( labels, 0, weights ); + } + + /** + * Compute the distance and nearest neighbors of a label image. + *

+ * Returns the joint distance transform of all non-background labels in the + * {@code labels} image. This distance transform will be the distance to the + * nearest non-background label at every point. + *

+ * Simultaneously, modifies the {@code labels} image so that the label at + * every point is the closest label, where initial distances are given by + * the values in the distance argument. + * + * @param + * the label type + * @param labels + * the label image + * @param backgroundLabel + * the background label + * @param weights + * for distance computation per dimension + * @return the distance map + */ + public static < L extends IntegerType< L > > RandomAccessibleInterval< DoubleType > voronoiDistanceTransform( + final RandomAccessibleInterval< L > labels, + final long backgroundLabel, + final double... weights ) + { + final RandomAccessibleInterval< DoubleType > distance = makeDistances( backgroundLabel, labels, new DoubleType() ); + voronoiDistanceTransform( labels, distance, weights ); + return distance; + } + + /** + * Compute the distance and nearest neighbors of a label image in parallel. + *

+ * Returns the joint distance transform of all non-background labels in the + * {@code labels} image. This distance transform will be the distance to the + * nearest non-background label at every point. + *

+ * Simultaneously, modifies the {@code labels} image so that the label at + * every point is the closest label, where initial distances are given by + * the values in the distance argument. + * + * @param + * the label type + * @param labels + * the label image + * @param backgroundLabel + * the background label + * @param es + * the ExecutorService + * @param nTasks + * the number of tasks in which to split the computation + * @param weights + * for distance computation per dimension + * @return the distance map + * @throws ExecutionException + * @throws InterruptedException + */ + public static < L extends IntegerType< L > > RandomAccessibleInterval< DoubleType > voronoiDistanceTransform( + final RandomAccessibleInterval< L > labels, + final long backgroundLabel, + final ExecutorService es, + final int nTasks, + final double... weights ) throws InterruptedException, ExecutionException + { + final RandomAccessibleInterval< DoubleType > distance = makeDistances( backgroundLabel, labels, new DoubleType() ); + labelTransform( labels, distance, es, nTasks, weights ); + return distance; + } + + /** + * Compute the distance transform, of the given distance image, storing the + * result in place. Simultaneously, modifies the labels image so that the + * value at every point is the closest label, where initial distances are + * given by the values in the distance argument. + * + * @param + * label type + * @param + * distance type + * @param labels + * the label image + * @param distance + * the distance image + * @param weights + * for distance computation per dimension + */ + public static < L extends IntegerType< L >, T extends RealType< T > > void voronoiDistanceTransform( + final RandomAccessibleInterval< L > labels, + final RandomAccessibleInterval< T > distance, + final double... weights ) + { + final Distance distanceFun = createEuclideanDistance(labels.numDimensions(), weights) ; + transformPropagateLabels( distance, distance, distance, labels, labels, distanceFun ); + } + + /** + * Compute the distance and nearest neighbors of a label image in parallel. + *

+ * Returns the joint distance transform of all non-background labels in the + * {@code labels} image. This distance transform will be the distance to the + * nearest non-background label at every point. + *

+ * Simultaneously, modifies the {@code labels} image so that the label at + * every point is the closest label, where initial distances are given by + * the values in the distance argument. + * + * @param + * the label type + * @param labels + * the label image + * @param es + * the ExecutorService + * @param nTasks + * the number of tasks in which to split the computation + * @param weights + * for distance computation per dimension + * @throws ExecutionException + * @throws InterruptedException + */ + public static < L extends IntegerType< L >, T extends RealType< T > > void labelTransform( + final RandomAccessibleInterval< L > labels, + final RandomAccessibleInterval< T > distance, + final ExecutorService es, + final int nTasks, + final double... weights ) throws InterruptedException, ExecutionException + { + final Distance distanceFun = createEuclideanDistance(labels.numDimensions(), weights) ; + transformPropagateLabels( distance, distance, distance, labels, labels, distanceFun, es, nTasks ); + } + + /** + * Computes the distance transform of the input distance map {@code distance} + * and simultaneously propagate a set of corresponding {@code labels}. + * Distance results are stored in the input {@code targetDistance} image. + *

+ * Simultaneously, propagates labels stored in the {@code labels} image so + * that the label at every point is the closest label, where initial + * distances are given by the values in the distance argument. Results of + * label propagation are stored in {@code labelsResult} + *

+ * This implementation operates in-place for both the {@code distance} and + * {@code labels} images. It uses an isotropic distance function with distance 1 + * between samples. + * + * @param distance + * Input function on which distance transform should be computed. + * @param labels + * Labels to be propagated. + * @param + * {@link RealType} input + * @param + * {@link IntegerType} the label type + */ + public static < T extends RealType< T >, L extends IntegerType< L > > void transformPropagateLabels( + final RandomAccessibleInterval< T > distance, + final RandomAccessibleInterval< L > labels) + { + transformPropagateLabels( distance, distance, labels, labels, new EuclidianDistanceIsotropic( 1 ) ); + } + + /** + * Computes the distance transform of the input distance map {@code distance} + * and simultaneously propagate a set of corresponding {@code labels}. + * Distance results are stored in the input {@code targetDistance} image. + *

+ * Simultaneously, propagates labels stored in the {@code labels} image so + * that the label at every point is the closest label, where initial + * distances are given by the values in the distance argument. Results of + * label propagation are stored in {@code labelsResult} + *

+ * This implementation uses the {@code targetDistance} for temporary storage. + * + * @param + * input distance {@link RealType} + * @param + * output distance {@link RealType} + * @param + * input label {@link IntegerType} + * @param + * output label {@link IntegerType} + * @param distance + * Input distance function from which distance transform should + * be computed. + * @param targetDistance + * Final result of distance transform. May be the same instance + * as distance. + * @param labels + * the image of labels to be propagated + * @param labelsResult + * the image in which to store the result of label propagation. + * May be the same instance as labels + * @param d + * the {@link Distance} function + */ + public static < T extends RealType< T >, U extends RealType< U >, L extends IntegerType, M extends IntegerType > void transformPropagateLabels( + final RandomAccessible< T > distance, + final RandomAccessibleInterval< U > targetDistance, + final RandomAccessible< L > labels, + final RandomAccessibleInterval< M > labelsResult, + final Distance d ) + { + transformPropagateLabels( distance, targetDistance, targetDistance, labels, labelsResult, d ); + } + + /** + * Computes the distance transform of the input distance map {@code distance} + * and simultaneously propagate a set of corresponding {@code labels}. + * Distance results are stored in the input {@code targetDistance} image. + *

+ * Simultaneously, propagates labels stored in the {@code labels} image so + * that the label at every point is the closest label, where initial + * distances are given by the values in the distance argument. Results of + * label propagation are stored in {@code labelsResult} + * + * @param + * input distance {@link RealType} + * @param + * intermediate distance result {@link RealType} + * @param + * output distance {@link RealType} + * @param + * input label {@link IntegerType} + * @param + * output label {@link IntegerType} + * @param distance + * Input distance function from which distance transform should + * be computed. + * @param tmpDistance + * Storage for intermediate distance results. + * @param targetDistance + * Final result of distance transform. May be the same instance + * as distance. + * @param labels + * the image of labels to be propagated + * @param labelsResult + * the image in which to store the result of label propagation. + * May be the same instance as labels + * @param d + * {@link Distance} between two points. + */ + public static < T extends RealType< T >, U extends RealType< U >, V extends RealType< V >, L extends IntegerType, M extends IntegerType > void transformPropagateLabels( + final RandomAccessible< T > distance, + final RandomAccessibleInterval< U > tmpDistance, + final RandomAccessibleInterval< V > targetDistance, + final RandomAccessible< L > labels, + final RandomAccessibleInterval< M > labelsResult, + final Distance d ) + { + assert distance.numDimensions() == targetDistance.numDimensions(): "Dimension mismatch"; + final int nDim = distance.numDimensions(); + final int lastDim = nDim - 1; + + if ( nDim == 1 ) + { + transformAlongDimensionPropagateLabels( + ( RandomAccessible< T > ) Views.addDimension( distance ), + Views.addDimension( targetDistance, 0, 0 ), + Views.addDimension( labels ), + Views.addDimension( labelsResult ), + d, 0 ); + } + else + { + transformAlongDimensionPropagateLabels( distance, tmpDistance, labels, labelsResult, d, 0 ); + } + + for ( int dim = 1; dim < nDim; ++dim ) + { + if ( dim == lastDim ) + { + transformAlongDimensionPropagateLabels( tmpDistance, targetDistance, labels, labelsResult, d, dim ); + } + else + { + transformAlongDimensionPropagateLabels( tmpDistance, tmpDistance, labels, labelsResult, d, dim ); + } + } + } + + /** + * In parallel, computes the distance transform of the input distance map {@code distance} + * and simultaneously propagate a set of corresponding {@code labels}. + * Distance results are stored in the input {@code targetDistance} image. + *

+ * Simultaneously, propagates labels stored in the {@code labels} image so + * that the label at every point is the closest label, where initial + * distances are given by the values in the distance argument. Results of + * label propagation are stored in {@code labelsResult} + * + * @param + * input distance {@link RealType} + * @param + * intermediate distance result {@link RealType} + * @param + * output distance {@link RealType} + * @param + * input label {@link IntegerType} + * @param + * output label {@link IntegerType} + * @param distance + * Input distance function from which distance transform should + * be computed. + * @param tmpDistance + * Storage for intermediate distance results. + * @param targetDistance + * Final result of distance transform. May be the same instance + * as distance. + * @param labels + * the image of labels to be propagated + * @param labelsResult + * the image in which to store the result of label propagation. + * May be the same instance as labels + * @param d + * {@link Distance} between two points. + * @param es + * the ExecutorService + * @param nTasks + * the number of tasks in which to split the computation + * @throws ExecutionException + * @throws InterruptedException + */ + public static < T extends RealType< T >, U extends RealType< U >, V extends RealType< V >, L extends IntegerType, M extends IntegerType > void transformPropagateLabels( + final RandomAccessible< T > distance, + final RandomAccessibleInterval< U > tmpDistance, + final RandomAccessibleInterval< V > targetDistance, + final RandomAccessible< L > labels, + final RandomAccessibleInterval< M > labelsResult, + final Distance d, + final ExecutorService es, + final int nTasks) throws InterruptedException, ExecutionException + { + assert distance.numDimensions() == targetDistance.numDimensions(): "Dimension mismatch"; + final int nDim = distance.numDimensions(); + final int lastDim = nDim - 1; + + if ( nDim == 1 ) + { + transformAlongDimensionPropagateLabels( + ( RandomAccessible< T > ) Views.addDimension( distance ), + Views.addDimension( targetDistance, 0, 0 ), + Views.addDimension( labels ), + Views.addDimension( labelsResult ), + d, 0 ); + } + else + { + transformAlongDimensionPropagateLabelsParallel( distance, tmpDistance, labels, labelsResult, d, 0, es, nTasks ); + } + + for ( int dim = 1; dim < nDim; ++dim ) + { + if ( dim == lastDim ) + { + transformAlongDimensionPropagateLabelsParallel( tmpDistance, targetDistance, labels, labelsResult, d, dim, es, nTasks); + } + else + { + transformAlongDimensionPropagateLabelsParallel( tmpDistance, tmpDistance, labels, labelsResult, d, dim, es, nTasks); + } + } + } + + private static < T extends RealType< T >, U extends RealType< U >, L extends IntegerType< L >, M extends IntegerType< M > > void transformAlongDimensionPropagateLabels( + final RandomAccessible< T > source, + final RandomAccessibleInterval< U > target, + final RandomAccessible< L > labelSource, + final RandomAccessible< M > labelTarget, + final Distance d, + final int dim ) + { + final int lastDim = target.numDimensions() - 1; + final long size = target.dimension( dim ); + + final Img< DoubleType > tmpImg = createAppropriateOneDimensionalImage( size, new DoubleType() ); + final RealComposite< DoubleType > tmp = Views.collapseReal( tmpImg ).randomAccess().get(); + + final Img< L > tmpLabelImg = Util.getSuitableImgFactory( tmpImg, labelSource.getType() ).create( tmpImg ); + final RealComposite< L > tmpLabel = Views.collapseReal( tmpLabelImg ).randomAccess().get(); + + // do not permute if we already work on last dimension + final Cursor< RealComposite< T > > s = Views.flatIterable( Views.collapseReal( dim == lastDim ? Views.interval( source, target ) : Views.permute( Views.interval( source, target ), dim, lastDim ) ) ).cursor(); + final Cursor< RealComposite< U > > t = Views.flatIterable( Views.collapseReal( dim == lastDim ? target : Views.permute( target, dim, lastDim ) ) ).cursor(); + + final Cursor< RealComposite< L > > ls = Views.flatIterable( + Views.collapseReal( dim == lastDim ? Views.interval( labelSource, target ) : Views.permute( Views.interval( labelSource, target ), dim, lastDim ) ) ).cursor(); + + final Cursor< RealComposite< M > > lt = Views.flatIterable( + Views.collapseReal( dim == lastDim ? Views.interval( labelTarget, target ) : Views.permute( Views.interval( labelTarget, target ), dim, lastDim ) ) ).cursor(); + + final RealComposite< LongType > lowerBoundDistanceIndex = Views.collapseReal( createAppropriateOneDimensionalImage( size, new LongType() ) ).randomAccess().get(); + final RealComposite< DoubleType > envelopeIntersectLocation = Views.collapseReal( createAppropriateOneDimensionalImage( size + 1, new DoubleType() ) ).randomAccess().get(); + + while ( s.hasNext() ) + { + final RealComposite< T > sourceComp = s.next(); + final RealComposite< U > targetComp = t.next(); + final RealComposite< L > labelComp = ls.next(); + final RealComposite< M > labelTargetComp = lt.next(); + for ( long i = 0; i < size; ++i ) + { + tmp.get( i ).set( sourceComp.get( i ).getRealDouble() ); + tmpLabel.get( i ).setInteger( labelComp.get( i ).getIntegerLong() ); + } + transformSingleColumnPropagateLabels( tmp, targetComp, tmpLabel, labelTargetComp, lowerBoundDistanceIndex, envelopeIntersectLocation, d, dim, size ); + } + } + + private static < T extends RealType< T >, U extends RealType< U >, L extends IntegerType, M extends IntegerType > void transformSingleColumnPropagateLabels( + final RealComposite< T > source, + final RealComposite< U > target, + final RealComposite< L > labelsSource, + final RealComposite< M > labelsResult, + final RealComposite< LongType > lowerBoundDistanceIndex, + final RealComposite< DoubleType > envelopeIntersectLocation, + final Distance d, + final int dim, + final long size ) + { + long k = 0; + + lowerBoundDistanceIndex.get( 0 ).set( 0 ); + envelopeIntersectLocation.get( 0 ).set( Double.NEGATIVE_INFINITY ); + envelopeIntersectLocation.get( 1 ).set( Double.POSITIVE_INFINITY ); + for ( long position = 1; position < size; ++position ) + { + long envelopeIndexAtK = lowerBoundDistanceIndex.get( k ).get(); + final double sourceAtPosition = source.get( position ).getRealDouble(); + double s = d.intersect( envelopeIndexAtK, source.get( envelopeIndexAtK ).getRealDouble(), position, sourceAtPosition, dim ); + + for ( double envelopeValueAtK = envelopeIntersectLocation.get( k ).get(); s <= envelopeValueAtK; envelopeValueAtK = envelopeIntersectLocation.get( k ).get() ) + { + --k; + envelopeIndexAtK = lowerBoundDistanceIndex.get( k ).get(); + s = d.intersect( envelopeIndexAtK, source.get( envelopeIndexAtK ).getRealDouble(), position, sourceAtPosition, dim ); + } + ++k; + lowerBoundDistanceIndex.get( k ).set( position ); + envelopeIntersectLocation.get( k ).set( s ); + envelopeIntersectLocation.get( k + 1 ).set( Double.POSITIVE_INFINITY ); + } + + k = 0; + for ( long position = 0; position < size; ++position ) + { + while ( envelopeIntersectLocation.get( k + 1 ).get() < position ) + { + ++k; + } + final long envelopeIndexAtK = lowerBoundDistanceIndex.get( k ).get(); + // copy necessary because of the following line, access to source + // after write to source -> source and target cannot be the same + target.get( position ).setReal( d.evaluate( position, envelopeIndexAtK, source.get( envelopeIndexAtK ).getRealDouble(), dim ) ); + labelsResult.get( position ).setInteger( labelsSource.get( envelopeIndexAtK ).getIntegerLong() ); + } + + } + + private static < T extends RealType< T >, U extends RealType< U >, L extends IntegerType< L >, M extends IntegerType< M > > void transformAlongDimensionPropagateLabelsParallel( + final RandomAccessible< T > source, + final RandomAccessibleInterval< U > target, + final RandomAccessible< L > labelSource, + final RandomAccessible< M > labelTarget, + final Distance d, + final int dim, + final ExecutorService es, + final int nTasks ) throws InterruptedException, ExecutionException + { + int largestDim = getLargestDimension( Views.hyperSlice( target, dim, target.min( dim ) ) ); + // ignore dimension along which we calculate transform + if ( largestDim >= dim ) + { + largestDim += 1; + } + final long size = target.dimension( dim ); + final long stepPerChunk = Math.max( size / nTasks, 1 ); + + final long[] min = Intervals.minAsLongArray( target ); + final long[] max = Intervals.maxAsLongArray( target ); + + final long largestDimMin = target.min( largestDim ); + final long largestDimMax = target.max( largestDim ); + + final ArrayList< Callable< Void > > tasks = new ArrayList<>(); + for ( long m = largestDimMin, M = largestDimMin + stepPerChunk - 1; m <= largestDimMax; m += stepPerChunk, M += stepPerChunk ) + { + min[ largestDim ] = m; + max[ largestDim ] = Math.min( M, largestDimMax ); + final Interval fi = new FinalInterval( min, max ); + tasks.add( () -> { + transformAlongDimensionPropagateLabels( source, Views.interval( target, fi ), + labelSource, Views.interval( labelTarget, fi ), + d, dim ); + return null; + } ); + } + + invokeAllAndWait( es, tasks ); + } + + private static Distance createEuclideanDistance( int numDimensions, double... weights ) + { + final boolean isIsotropic = weights.length <= 1; + final double[] w = weights.length == numDimensions ? weights : DoubleStream.generate( () -> weights.length == 0 ? 1.0 : weights[ 0 ] ).limit( numDimensions ).toArray(); + return isIsotropic ? new EuclidianDistanceIsotropic( w[ 0 ] ) : new EuclidianDistanceAnisotropic( w ); + } + /** * Convenience method to find largest dimension of {@link Interval} * interval. @@ -1336,6 +1876,23 @@ public static int getLargestDimension( final Interval interval ) return IntStream.range( 0, interval.numDimensions() ).mapToObj( i -> new ValuePair<>( i, interval.dimension( i ) ) ).max( ( p1, p2 ) -> Long.compare( p1.getB(), p2.getB() ) ).get().getA(); } + private static < L extends IntegerType< L >, T extends RealType< T > > RandomAccessibleInterval< T > makeDistances( + final long label, + final RandomAccessibleInterval< L > labels, + final T type) { + + final T maxVal = type.copy(); + maxVal.setReal( maxVal.getMaxValue() ); + + final RandomAccessibleInterval< T > distances = Util.getSuitableImgFactory( labels, type ).create( labels ); + Views.pair( labels, distances ).view().interval( labels ).forEach( pair -> { + if( pair.getA().getIntegerLong() == label ) + pair.getB().set( maxVal ); + }); + + return distances; + } + private static class BinaryMaskToCost< B extends BooleanType< B >, R extends RealType< R > > implements Converter< B, R > { diff --git a/src/test/java/net/imglib2/algorithm/morphology/distance/DistanceTransformTest.java b/src/test/java/net/imglib2/algorithm/morphology/distance/DistanceTransformTest.java index ff623beaf..9129ad728 100644 --- a/src/test/java/net/imglib2/algorithm/morphology/distance/DistanceTransformTest.java +++ b/src/test/java/net/imglib2/algorithm/morphology/distance/DistanceTransformTest.java @@ -34,8 +34,15 @@ package net.imglib2.algorithm.morphology.distance; +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; + +import java.util.ArrayList; import java.util.Arrays; +import java.util.HashSet; +import java.util.List; import java.util.Random; +import java.util.Set; import java.util.concurrent.ExecutionException; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; @@ -46,6 +53,7 @@ import org.junit.Test; import net.imglib2.Cursor; +import net.imglib2.Interval; import net.imglib2.Localizable; import net.imglib2.Point; import net.imglib2.RandomAccess; @@ -58,16 +66,20 @@ import net.imglib2.img.basictypeaccess.array.DoubleArray; import net.imglib2.img.basictypeaccess.array.LongArray; import net.imglib2.type.logic.BitType; +import net.imglib2.type.numeric.IntegerType; import net.imglib2.type.numeric.RealType; +import net.imglib2.type.numeric.integer.LongType; import net.imglib2.type.numeric.real.DoubleType; import net.imglib2.util.Intervals; +import net.imglib2.util.Localizables; import net.imglib2.util.Pair; +import net.imglib2.util.Util; import net.imglib2.view.Views; /** * * @author Philipp Hanslovsky - * + * @author John Bogovic */ public class DistanceTransformTest { @@ -159,9 +171,7 @@ private void testBinary( final DISTANCE_TYPE dt, final DistanceCalculator distan private static void compareRAIofRealType( final RandomAccessibleInterval< ? extends RealType< ? > > ref, final RandomAccessibleInterval< ? extends RealType< ? > > comp, final double tolerance ) { - Assert.assertArrayEquals( Intervals.dimensionsAsLongArray( ref ), Intervals.dimensionsAsLongArray( comp ) ); - Assert.assertArrayEquals( Intervals.minAsLongArray( ref ), Intervals.minAsLongArray( comp ) ); - Assert.assertArrayEquals( Intervals.maxAsLongArray( ref ), Intervals.maxAsLongArray( comp ) ); + assertTrue( Intervals.equals( ref, comp ) ); for ( final Pair< ? extends RealType< ? >, ? extends RealType< ? > > p : Views.flatIterable( Views.interval( Views.pair( ref, comp ), ref ) ) ) { Assert.assertEquals( p.getA().getRealDouble(), p.getB().getRealDouble(), tolerance ); @@ -440,7 +450,7 @@ private static < T extends RealType< T > > void checkDistance( final double[] weights, final DistanceCalculator distanceCalculator ) { - for ( final Cursor< T > c = Views.iterable( dist ).localizingCursor(); c.hasNext(); ) + for ( final Cursor< T > c = dist.localizingCursor(); c.hasNext(); ) { final double actual = c.next().getRealDouble(); final double expected = atSamePosition( foreground, c ) ? 0.0 : distanceCalculator.dist( foreground, c, weights ); @@ -448,4 +458,186 @@ private static < T extends RealType< T > > void checkDistance( } } + @Test + public void testLabelPropagation() + { + /* + * Iterate over numReplicates = [0..9] numDimensions = [2, 3] numLabels + * = [1..5] + */ + final int firstReplicate = 0; + final int lastReplicate = 9; + + final int firstNumDimensions = 2; + final int lastNumDimensions = 3; + + final int firstNumLabels = 2; + final int lastNumLabels = 5; + + final RandomAccessibleInterval< Localizable > parameters = Localizables.randomAccessibleInterval( + Intervals.createMinMax( + firstReplicate, firstNumDimensions, + firstNumLabels, lastReplicate, + lastNumDimensions, lastNumLabels ) ); + + parameters.forEach( params -> { + + @SuppressWarnings( "unused" ) + final int replicate = params.getIntPosition( 0 ); + final int numDimensions = params.getIntPosition( 1 ); + final int numLabels = params.getIntPosition( 2 ); + + testLabelPropagationHelper( numDimensions, numLabels ); + try + { + testLabelPropagationHelperParallel( numDimensions, numLabels ); + } + catch ( Exception e ) + { + e.printStackTrace(); + fail(); + } + } ); + + } + + /** + * Creates an label and distances images with the requested number of dimensions (ndims), + * and places nLabels points with non-zero label. Checks that the propagated labels correctly + * reflect the nearest label (ties are allowed: any label equi-distant to a point passes). + * + * @param ndims number of dimensions + * @param nLabels number of labels + */ + private void testLabelPropagationHelper( int ndims, int nLabels ) + { + + final long[] imgDims = LongStream.iterate( dimensionSize, d -> d - 1 ).limit( ndims ).toArray(); + final ArrayImg< LongType, LongArray > labels = ArrayImgs.longs( imgDims ); + + final Set< PointAndLabel > points = initializeLabels( rng, nLabels, labels ); + DistanceTransform.voronoiDistanceTransform( labels, 0 ); + validateLabelsSet( "serial", points, labels ); + } + + /** + * Creates an label and distances images with the requested number of dimensions (ndims), + * and places nLabels points with non-zero label. Checks that the propagated labels correctly + * reflect the nearest label (ties are allowed: any label equi-distant to a point passes). + * + * @param ndims number of dimensions + * @param nLabels number of labels + * @throws ExecutionException + * @throws InterruptedException + */ + private void testLabelPropagationHelperParallel( int ndims, int nLabels ) throws InterruptedException, ExecutionException + { + + final long[] imgDims = LongStream.iterate( dimensionSize, d -> d - 1 ).limit( ndims ).toArray(); + final ArrayImg< LongType, LongArray > labels = ArrayImgs.longs( imgDims ); + final Set< PointAndLabel > points = initializeLabels( rng, nLabels, labels ); + DistanceTransform.voronoiDistanceTransform( labels, 0, es, 3 * nThreads ); + validateLabelsSet( "parallel", points, labels ); + } + + private ArrayImg< LongType, LongArray > copyLongArrayImg( ArrayImg< LongType, LongArray > img ) + { + + final long[] dataOrig = img.getAccessType().getCurrentStorageArray(); + final long[] dataCopy = new long[ dataOrig.length ]; + System.arraycopy( dataOrig, 0, dataCopy, 0, dataOrig.length ); + return ArrayImgs.longs( dataCopy, img.dimensionsAsLongArray() ); + } + + private static Point randomPointInInterval( final Random rng, final Interval itvl ) + { + final int[] coords = IntStream.range( 0, itvl.numDimensions() ).map( i -> { + return rng.nextInt( ( int ) itvl.dimension( i ) ); + } ).toArray(); + return new Point( coords ); + } + + private static < T extends RealType< T >, L extends IntegerType< L > > Set< PointAndLabel > initializeLabels( Random random, int numLabels, RandomAccessibleInterval< L > labels ) + { + labels.forEach( p -> p.setZero() ); // Initialize all labels to 0 + Set< PointAndLabel > positions = new HashSet<>(); + + int currentLabel = 1; + // Set numLabels different random positions to a non-zero label + while ( positions.size() < numLabels ) + { + final Point pt = randomPointInInterval( random, labels ); + if ( !positions.contains( pt ) ) + { + + final PointAndLabel candidate = new PointAndLabel( currentLabel, pt.positionAsLongArray() ); + if ( !positions.contains( candidate ) ) + { + positions.add( candidate ); + labels.randomAccess().setPositionAndGet( pt ).setInteger( currentLabel ); + currentLabel++; + } + + } + } + return positions; + } + + /** + * Return the set of points within epsilon distance of the query point + * + * @param query point + * @param pointSet set of candidate points + * @param epsilon distance threshold + * @return the set of close points + */ + private static List< PointAndLabel > closestSet( Localizable query, Set< PointAndLabel > pointSet, final double epsilon ) + { + + final List< PointAndLabel > listOfEquidistant = new ArrayList<>(); + + double mindist = Double.MAX_VALUE; + for ( PointAndLabel pt : pointSet ) + { + double dist = Util.distance( query, pt ); + + if ( Math.abs( dist - mindist ) < epsilon ) + { + listOfEquidistant.add( pt ); + } + else if ( dist < mindist ) + { + mindist = dist; + listOfEquidistant.clear(); + listOfEquidistant.add( pt ); + } + } + + return listOfEquidistant; + } + + private static < T extends RealType< T >, L extends IntegerType< L > > void validateLabelsSet( final String prefix, final Set< PointAndLabel > points, final RandomAccessibleInterval< L > labels ) + { + final double EPS = 0.01; + final Cursor< L > c = labels.cursor(); + while ( c.hasNext() ) + { + c.fwd(); + final boolean labelIsClosest = closestSet( c, points, EPS ).stream().anyMatch( p -> p.label == c.get().getIntegerLong() ); + assertTrue( prefix + " point: " + Arrays.toString( c.positionAsLongArray() ), labelIsClosest ); + } + } + + private static class PointAndLabel extends Point + { + + long label; + + public PointAndLabel( long label, long[] position ) + { + super( position ); + this.label = label; + } + } + }