Skip to content

[SPARK-49133][CORE] Make member MemoryConsumer#used atomic to avoid user code causing deadlock #51849

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 8 additions & 7 deletions core/src/main/java/org/apache/spark/memory/MemoryConsumer.java
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
package org.apache.spark.memory;

import java.io.IOException;
import java.util.concurrent.atomic.AtomicLong;

import org.apache.spark.errors.SparkCoreErrors;
import org.apache.spark.unsafe.array.LongArray;
Expand All @@ -33,7 +34,7 @@ public abstract class MemoryConsumer {
protected final TaskMemoryManager taskMemoryManager;
private final long pageSize;
private final MemoryMode mode;
protected long used;
protected final AtomicLong used = new AtomicLong(0L);

protected MemoryConsumer(TaskMemoryManager taskMemoryManager, long pageSize, MemoryMode mode) {
this.taskMemoryManager = taskMemoryManager;
Expand All @@ -56,7 +57,7 @@ public MemoryMode getMode() {
* Returns the size of used memory in bytes.
*/
public long getUsed() {
return used;
return used.get();
}

/**
Expand Down Expand Up @@ -97,7 +98,7 @@ public LongArray allocateArray(long size) {
if (page == null || page.size() < required) {
throwOom(page, required);
}
used += required;
used.getAndAdd(required);
return new LongArray(page);
}

Expand All @@ -118,15 +119,15 @@ protected MemoryBlock allocatePage(long required) {
if (page == null || page.size() < required) {
throwOom(page, required);
}
used += page.size();
used.getAndAdd(page.size());
return page;
}

/**
* Free a memory block.
*/
protected void freePage(MemoryBlock page) {
used -= page.size();
used.getAndAdd(-page.size());
taskMemoryManager.freePage(page, this);
}

Expand All @@ -135,7 +136,7 @@ protected void freePage(MemoryBlock page) {
*/
public long acquireMemory(long size) {
long granted = taskMemoryManager.acquireExecutionMemory(size, this);
used += granted;
used.getAndAdd(granted);
return granted;
}

Expand All @@ -144,7 +145,7 @@ public long acquireMemory(long size) {
*/
public void freeMemory(long size) {
taskMemoryManager.releaseExecutionMemory(size, this);
used -= size;
used.getAndAdd(-size);
}

private void throwOom(final MemoryBlock page, final long required) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,17 +40,17 @@ public long spill(long size, MemoryConsumer trigger) throws IOException {

public void use(long size) {
long got = taskMemoryManager.acquireExecutionMemory(size, this);
used += got;
used.getAndAdd(got);
}

public void free(long size) {
used -= size;
used.getAndAdd(-size);
taskMemoryManager.releaseExecutionMemory(size, this);
}

@VisibleForTesting
public void freePage(MemoryBlock page) {
used -= page.size();
used.getAndAdd(-page.size());
taskMemoryManager.freePage(page, this);
}
}
Expand Down