Skip to content

Commit

Permalink
Merge pull request #539 from saalfeldlab/perf/1.3.1
Browse files Browse the repository at this point in the history
1.3.1
  • Loading branch information
cmhulbert authored Jun 13, 2024
2 parents bc3b5a8 + 3c643d4 commit 687e4f9
Show file tree
Hide file tree
Showing 15 changed files with 131 additions and 79 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -529,6 +529,8 @@ public int paint(
// if rendering was not cancelled...
if (success) {
if (createProjector) {
if (currentScreenScaleIndex >= screenImages.size() || reuseBufferScreenScale >= screenImages.size())
return -1;
final ArrayDeque<T> buffers = screenImages.get(currentScreenScaleIndex);
final T renderTarget = doubleBuffered ? buffers.pop() : buffers.peek();
final T unusedBuffer = display.setBufferedImageAndTransform(renderTarget, currentProjectorTransform);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,36 +45,37 @@ public void run() {
while (this.isRunning) {
boolean paint;
synchronized (this) {
paint = this.pleaseRepaint;
this.pleaseRepaint = false;
paint = pleaseRepaint;
pleaseRepaint = false;
}

if (paint) {
try {
this.paintable.paint();
paintable.paint();
} catch (RejectedExecutionException var5) {
}
}

synchronized (this) {
try {
if (this.isRunning && !this.pleaseRepaint) {
this.wait();
if (isRunning && !pleaseRepaint) {
wait();
}
continue;
if (isRunning)
continue;
} catch (InterruptedException var7) {
System.out.println(Thread.currentThread().getName() + " interrupted");
}
}

return;
}
}

public void requestRepaint() {

synchronized (this) {
this.pleaseRepaint = true;
this.notify();
pleaseRepaint = true;
notify();
}
}

Expand All @@ -85,12 +86,10 @@ public interface Paintable {

public void stopRendering() {

synchronized (this) {
LOG.debug("Stop rendering now!");
this.isRunning = false;
this.notify();
LOG.debug("Notified on this ({})", this);
}
LOG.debug("Stop rendering now!");
isRunning = false;
interrupt();
LOG.debug("Notified on this ({})", this);
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -779,6 +779,8 @@ public void resetMasks(final boolean clearOldMask) throws MaskInUse {
if (!canResetMask)
throw new MaskInUse("Cannot reset the mask.");

var mask = getCurrentMask();
if (mask != null && mask.shutdown != null) mask.shutdown.run();
setCurrentMask(null);
this.isBusy.set(true);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,11 @@ static long randomTemporaryId() {
return randomTemps.findFirst().getAsLong();
}

static boolean isTemporary(final long id) {

return id > FIRST_TEMPORARY_ID && id < LAST_TEMPORARY_ID;
}

class IdServiceNotProvided implements IdService {

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,12 @@ public List<ActionSet> makeActionSets(KeyTracker keyTracker, Supplier<ViewerPane
};
}

public long activateCurrentOrNext() {
if (selectedIds.isLastSelectionValid())
return selectedIds.getLastSelection();
else return nextId(true);
}

public long nextId() {

return nextId(true);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,11 +70,11 @@ public abstract class AbstractHighlightingARGBStream extends ObservableWithListe
private final BooleanProperty colorFromSegmentId = new SimpleBooleanProperty();

protected final TLongIntHashMap explicitlySpecifiedColors = new TLongIntHashMap();
protected final TLongIntHashMap overrideAlpha = new TLongIntHashMap(
public final TLongIntHashMap overrideAlpha = new TLongIntHashMap(
Constants.DEFAULT_CAPACITY,
Constants.DEFAULT_LOAD_FACTOR,
Label.INVALID,
0
-1
);

public AbstractHighlightingARGBStream(
Expand Down Expand Up @@ -295,8 +295,10 @@ public boolean getHideLockedSegments() {
public void specifyColorExplicitly(final long segmentId, final int color, final boolean overrideAlpha) {

this.explicitlySpecifiedColors.put(segmentId, color);
if (overrideAlpha)
this.overrideAlpha.put(segmentId, 1);
if (overrideAlpha) {
var alpha = color >>> 24;
this.overrideAlpha.put(segmentId, alpha);
}
clearCache();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,16 +72,18 @@ protected int argbImpl(final long fragmentId, final boolean colorFromSegmentId)
} else if (lockedSegments.isLocked(segmentId) && hideLockedSegments) {
argb = argb & 0x00ffffff;
} else {
if (overrideAlpha.get(assigned) != 1) {
var a = alpha;
int alphaOverride = overrideAlpha.get(assigned);
int actualAlpha = alpha;
if (alphaOverride == overrideAlpha.getNoEntryValue()) {
if (isActiveSegment) {
if (isActiveFragment(fragmentId))
a = activeFragmentAlpha;
actualAlpha = activeFragmentAlpha;
else
a = activeSegmentAlpha;
actualAlpha = activeSegmentAlpha;
}
argb = argb & 0x00ffffff | a;
}
} else
actualAlpha = alphaOverride << 24;
argb = argb & 0x00ffffff | actualAlpha;
}

return argb;
Expand Down
9 changes: 5 additions & 4 deletions src/main/java/org/janelia/saalfeldlab/util/n5/N5Data.java
Original file line number Diff line number Diff line change
Expand Up @@ -800,7 +800,7 @@ public static void createEmptyLabelDataset(
throw new IOException(String.format("Unique labels group `%s' exists in container `%s' -- conflict likely.", uniqueLabelsGroup, container));

n5.setAttribute(group, N5Helpers.PAINTERA_DATA_KEY, pd);
n5.setAttribute(group, N5Helpers.MAX_ID_KEY, 1L);
n5.setAttribute(group, N5Helpers.MAX_ID_KEY, 0L);

final String dataGroup = String.format("%s/data", group);
n5.createGroup(dataGroup);
Expand Down Expand Up @@ -828,15 +828,16 @@ public static void createEmptyLabelDataset(

final String dataset = String.format(scaleDatasetPattern, scaleLevel);
final String uniqeLabelsDataset = String.format(scaleUniqueLabelsPattern, scaleLevel);
n5.createDataset(dataset, scaledDimensions, blockSize, DataType.UINT8, new GzipCompression());
n5.createDataset(uniqeLabelsDataset, scaledDimensions, blockSize, DataType.UINT64, new GzipCompression());

if (labelMultisetType) {
n5.createDataset(dataset, scaledDimensions, blockSize, DataType.UINT8, new GzipCompression());
final int maxNum = downscaledLevel < 0 ? -1 : maxNumEntries[downscaledLevel];
n5.setAttribute(dataset, N5Helpers.MAX_NUM_ENTRIES_KEY, maxNum);
n5.setAttribute(dataset, N5Helpers.IS_LABEL_MULTISET_KEY, true);
}
} else
n5.createDataset(dataset, scaledDimensions, blockSize, DataType.UINT64, new GzipCompression());

n5.createDataset(uniqeLabelsDataset, scaledDimensions, blockSize, DataType.UINT64, new GzipCompression());
if (scaleLevel != 0) {
n5.setAttribute(dataset, N5Helpers.DOWNSAMPLING_FACTORS_KEY, accumulatedFactors);
n5.setAttribute(uniqeLabelsDataset, N5Helpers.DOWNSAMPLING_FACTORS_KEY, accumulatedFactors);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,7 @@ object SamEmbeddingLoaderCache : AsyncCacheWithLoader<RenderUnitState, OnnxTenso
result.image?.let { img ->
ImageIO.write(SwingFXUtils.fromFXImage(img, null), "png", predictionImagePngOutputStream)
predictionImagePngOutputStream.close()
sharedQueue.shutdown()
imageRenderer.stopRendering()
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -666,6 +666,7 @@ class ShapeInterpolationController<D : IntegerType<D>>(

/* Replace old slice info */
slicesAndInterpolants.removeSlice(oldSlice)
oldSlice.mask.shutdown?.run()

val newSlice = SliceInfo(
newMask,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,9 @@ class ShapeInterpolationMode<D : IntegerType<D>>(val controller: ShapeInterpolat

override fun exit() {
super.exit()

SamEmbeddingLoaderCache.stopNavigationBasedRequests()
SamEmbeddingLoaderCache.invalidateAll()
paintera.baseView.disabledPropertyBindings.remove(controller)
controller.resetFragmentAlpha()
activeViewerProperty.removeListener(toolTriggerListener)
Expand Down Expand Up @@ -559,9 +561,36 @@ internal class SamSliceCache : HashMap<Float, SamSliceInfo>() {
operator fun minusAssign(key: Double) {
remove(key.toFloat())
}

operator fun minusAssign(key: Float) {
remove(key)
}

override fun clear() {
values.forEach {
it.mask.shutdown?.run()
}
super.clear()
}

override fun remove(key: Float): SamSliceInfo? {
return super.remove(key)?.also {
it.mask.shutdown?.run()
}
}

override fun remove(key: Float, value: SamSliceInfo): Boolean {
return if (super.remove(key, value)) {
value.mask.shutdown?.run()
true
} else false
}

override fun put(key: Float, value: SamSliceInfo): SamSliceInfo? {
return super.put(key, value)?.also {
it.mask.shutdown?.run()
}
}
}

internal data class SamSliceInfo(val renderState: RenderUnitState, val mask: ViewerMask, var prediction: SamPredictor.PredictionRequest, var sliceInfo: ShapeInterpolationController.SliceInfo?, var locked: Boolean = false) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,6 @@ import net.imglib2.view.Views
import org.apache.commons.io.output.NullPrintStream
import org.janelia.saalfeldlab.bdv.fx.viewer.ViewerPanelFX
import org.janelia.saalfeldlab.control.VPotControl
import org.janelia.saalfeldlab.fx.Tasks
import org.janelia.saalfeldlab.fx.UtilityTask
import org.janelia.saalfeldlab.fx.actions.*
import org.janelia.saalfeldlab.fx.actions.ActionSet.Companion.installActionSet
import org.janelia.saalfeldlab.fx.extensions.LazyForeignValue
Expand Down Expand Up @@ -94,7 +92,6 @@ import org.janelia.saalfeldlab.paintera.util.IntervalHelpers.Companion.smallestC
import org.janelia.saalfeldlab.paintera.util.algorithms.otsuThresholdPrediction
import org.janelia.saalfeldlab.util.*
import java.util.concurrent.CancellationException
import java.util.concurrent.ForkJoinPool
import java.util.concurrent.LinkedBlockingQueue
import kotlin.collections.List
import kotlin.collections.MutableList
Expand Down Expand Up @@ -207,7 +204,7 @@ open class SamTool(activeSourceStateProperty: SimpleObjectProperty<SourceState<*
originalWritableVolatileBackingImage = field?.volatileViewerImg?.writableSource
}

private var predictionTask: UtilityTask<Unit>? = null
private var predictionJob : Job = Job().apply { complete() }

internal val lastPredictionProperty = SimpleObjectProperty<SamTaskInfo?>(null)
var lastPrediction by lastPredictionProperty.nullable()
Expand Down Expand Up @@ -299,8 +296,7 @@ open class SamTool(activeSourceStateProperty: SimpleObjectProperty<SourceState<*
internal fun cleanup() {
clearPromptDrawings()
currentLabelToPaint = Label.INVALID
predictionTask?.cancel()
predictionTask = null
predictionJob.cancel()
if (unwrapResult) {
if (!maskProvided) {
maskedSource?.resetMasks()
Expand Down Expand Up @@ -833,8 +829,8 @@ open class SamTool(activeSourceStateProperty: SimpleObjectProperty<SourceState<*
}

open fun requestPrediction(promptPoints: List<SamPoint>, estimateThreshold: Boolean = true) {
if (predictionTask == null || predictionTask?.isCancelled == true) {
startPredictionTask()
if (!predictionJob.isActive) {
startPredictionJob()
}
currentPredictionRequest = SamPredictor.points(promptPoints.toList()) to estimateThreshold
}
Expand All @@ -849,25 +845,25 @@ open class SamTool(activeSourceStateProperty: SimpleObjectProperty<SourceState<*
private var embeddingRequest: Deferred<OnnxTensor>? = null

private var currentPrediction: SamPredictor.SamPrediction? = null
private fun startPredictionTask() {
private fun startPredictionJob() {
val maskSource = maskedSource ?: return
val task = Tasks.createTask { task ->
predictionJob = SAM_TASK_SCOPE.launch {
val session = createOrtSessionTask.get()
val imageEmbedding = try {
runBlocking {
isBusy = true
embeddingRequest!!.await()
}
} catch (e: InterruptedException) {
if (!task.isCancelled) throw e
return@createTask
if (coroutineContext.isActive) throw e
return@launch
} catch (e: CancellationException) {
return@createTask
return@launch
} finally {
isBusy = false
}
val predictor = SamPredictor(ortEnv, session, imageEmbedding, imgWidth to imgHeight)
while (!task.isCancelled) {
while (coroutineContext.isActive) {
val predictionPair = predictionQueue.take()
val (predictionRequest, estimateThreshold) = predictionPair
val points = (predictionRequest as SparsePrediction).points
Expand Down Expand Up @@ -933,7 +929,7 @@ open class SamTool(activeSourceStateProperty: SimpleObjectProperty<SourceState<*
} catch (e: InterruptedException) {
System.setErr(stdErr)
LOG.debug(e) { "Connected Components Interrupted During SAM" }
task.cancel()
cancel("Connected Components Interrupted During SAM" )
continue
} finally {
System.setErr(stdErr)
Expand Down Expand Up @@ -1048,8 +1044,6 @@ open class SamTool(activeSourceStateProperty: SimpleObjectProperty<SourceState<*
lastPrediction = SamTaskInfo(maskSource, predictionIntervalInViewerSpace, imageEmbedding, predictionRequest)
}
}
predictionTask = task
task.submit(SAM_TASK_SERVICE)
}

private fun setBestEstimatedThreshold(interval: Interval? = null) {
Expand Down Expand Up @@ -1140,15 +1134,7 @@ open class SamTool(activeSourceStateProperty: SimpleObjectProperty<SourceState<*

private val LOG = KotlinLogging.logger { }

internal val SAM_TASK_SERVICE = ForkJoinPool.ForkJoinWorkerThreadFactory { pool: ForkJoinPool ->
val worker = ForkJoinPool.defaultForkJoinWorkerThreadFactory.newThread(pool)
worker.isDaemon = true
worker.priority = 4
worker.name = "sam-task-" + worker.poolIndex
worker
}.let { factory ->
ForkJoinPool(max(4.0, 4 * (Runtime.getRuntime().availableProcessors()).toDouble()).toInt(), factory, null, false)
}
internal val SAM_TASK_SCOPE = CoroutineScope(Dispatchers.IO + Job())


private fun calculateTargetScreenScaleFactor(viewer: ViewerPanelFX): Double {
Expand Down
Loading

0 comments on commit 687e4f9

Please sign in to comment.