Skip to content

Commit

Permalink
Close the stdin pipe handle when the respective OutputStream is closed.
Browse files Browse the repository at this point in the history
Otherwise, a subprocess that consumes the stdin in its entirety can never terminate.

PiperOrigin-RevId: 635774207
Change-Id: I134ddd1fee50faccb8ddb8400dbb11ce6a354c05
  • Loading branch information
tjgq authored and copybara-github committed May 21, 2024
1 parent daa61a7 commit b3f5c62
Show file tree
Hide file tree
Showing 5 changed files with 42 additions and 21 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,12 @@ public void write(int b) throws IOException {

@Override
public void write(byte[] b, int off, int len) throws IOException {
writeStream(b, off, len);
writeStdin(b, off, len);
}

@Override
public void close() {
closeStdin();
}
}

Expand Down Expand Up @@ -113,12 +118,6 @@ public synchronized void close() {
nativeStream = WindowsProcesses.INVALID;
}
}

@Override
protected void finalize() throws Throwable {
close();
super.finalize();
}
}

private static final AtomicInteger THREAD_SEQUENCE_NUMBER = new AtomicInteger(1);
Expand All @@ -138,7 +137,7 @@ public Thread newThread(Runnable runnable) {
});

private volatile long nativeProcess;
private final OutputStream stdinStream;
private final ProcessOutputStream stdinStream;
private final ProcessInputStream stdoutStream;
private final ProcessInputStream stderrStream;
private final Future<WaitResult> processFuture;
Expand Down Expand Up @@ -247,7 +246,10 @@ public void waitFor() throws InterruptedException {
@Override
public synchronized void close() {
if (nativeProcess != WindowsProcesses.INVALID) {
// stdoutStream and stderrStream are null if they are redirected to files.
// The streams are null when redirected from/to files.
if (stdinStream != null) {
stdinStream.close();
}
if (stdoutStream != null) {
stdoutStream.close();
}
Expand Down Expand Up @@ -275,7 +277,7 @@ public InputStream getErrorStream() {
return stderrStream;
}

private synchronized void writeStream(byte[] b, int off, int len) throws IOException {
private synchronized void writeStdin(byte[] b, int off, int len) throws IOException {
checkLiveness();

int remaining = len;
Expand All @@ -293,6 +295,11 @@ private synchronized void writeStream(byte[] b, int off, int len) throws IOExcep
}
}

private synchronized void closeStdin() {
checkLiveness();
WindowsProcesses.closeStdin(nativeProcess);
}

private void checkLiveness() {
if (nativeProcess == WindowsProcesses.INVALID) {
throw new IllegalStateException();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,10 +66,13 @@ public static long createProcess(
*
* <p>Blocks until either some data was written or the process is terminated.
*
* @return the number of bytes written
* @return the number of bytes written, or -1 if an error occurs.
*/
public static native int writeStdin(long process, byte[] bytes, int offset, int length);

/** Closes the stdin of the specified process. */
public static native void closeStdin(long process);

/** Returns an opaque identifier of stdout stream for the process. */
public static native long getStdout(long process);

Expand Down
8 changes: 7 additions & 1 deletion src/main/native/windows/processes-jni.cc
Original file line number Diff line number Diff line change
Expand Up @@ -472,6 +472,13 @@ Java_com_google_devtools_build_lib_windows_WindowsProcesses_writeStdin(
return process->WriteStdin(env, java_bytes, offset, length);
}

extern "C" JNIEXPORT void JNICALL
Java_com_google_devtools_build_lib_windows_WindowsProcesses_closeStdin(
JNIEnv* env, jclass clazz, jlong process_long) {
NativeProcess* process = reinterpret_cast<NativeProcess*>(process_long);
process->CloseStdin();
}

extern "C" JNIEXPORT jlong JNICALL
Java_com_google_devtools_build_lib_windows_WindowsProcesses_getStdout(
JNIEnv* env, jclass clazz, jlong process_long) {
Expand Down Expand Up @@ -519,7 +526,6 @@ Java_com_google_devtools_build_lib_windows_WindowsProcesses_waitFor(
JNIEnv* env, jclass clazz, jlong process_long, jlong java_timeout) {
NativeProcess* process = reinterpret_cast<NativeProcess*>(process_long);
int res = process->WaitFor(static_cast<int64_t>(java_timeout));
process->CloseStdin();
return static_cast<jint>(res);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -84,9 +84,14 @@ public static void main(String[] args) throws Exception {
case 'I':
char register = arg.charAt(1);
int length = Integer.parseInt(arg.substring(2));
byte[] buf = new byte[length];
byte[] buf;
if (length > 0) {
buf = new byte[length];
System.in.read(buf, 0, length);
} else {
buf = System.in.readAllBytes();
}
registers.put(register, buf);
System.in.read(buf, 0, length);
break;

case 'E':
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -124,23 +124,23 @@ public void testQuotesArgWithDoubleQuote() throws Exception {
}

@Test
public void testSmoke() throws Exception {
public void testOneShot() throws Exception {
process =
WindowsProcesses.createProcess(
mockBinary, mockArgs("Ia5", "Oa"), null, null, null, null);
WindowsProcesses.createProcess(mockBinary, mockArgs("Ia0", "Oa"), null, null, null, null);
assertNoProcessError();

byte[] input = "HELLO".getBytes(UTF_8);
byte[] output = new byte[5];
WindowsProcesses.writeStdin(process, input, 0, 5);
assertThat(WindowsProcesses.writeStdin(process, input, 0, 5)).isEqualTo(5);
WindowsProcesses.closeStdin(process);
assertNoProcessError();
readStdout(output, 0, 5);
assertNoStreamError(WindowsProcesses.getStdout(process));
assertThat(new String(output, UTF_8)).isEqualTo("HELLO");
}

@Test
public void testPingpong() throws Exception {
public void testChunks() throws Exception {
List<String> args = new ArrayList<>();
for (int i = 0; i < 100; i++) {
args.add("Ia3");
Expand All @@ -154,6 +154,7 @@ public void testPingpong() throws Exception {
byte[] input = String.format("%03d", i).getBytes(UTF_8);
assertThat(input.length).isEqualTo(3);
assertThat(WindowsProcesses.writeStdin(process, input, 0, 3)).isEqualTo(3);
assertNoProcessError();
byte[] output = new byte[3];
assertThat(readStdout(output, 0, 3)).isEqualTo(3);
assertThat(Integer.parseInt(new String(output, UTF_8))).isEqualTo(i);
Expand Down Expand Up @@ -267,8 +268,7 @@ public void testArrayOutOfBounds() throws Exception {
@Test
public void testOffsetedOps() throws Exception {
process =
WindowsProcesses.createProcess(
mockBinary, mockArgs("Ia3", "Oa"), null, null, null, null);
WindowsProcesses.createProcess(mockBinary, mockArgs("Ia3", "Oa"), null, null, null, null);
byte[] input = "01234".getBytes(UTF_8);
byte[] output = "abcde".getBytes(UTF_8);

Expand Down

0 comments on commit b3f5c62

Please sign in to comment.