From ad5d6586622ad328f40b561795f4faf91a1ce90a Mon Sep 17 00:00:00 2001
From: Jerjou Cheng
Date: Wed, 4 Jan 2017 10:04:09 -0800
Subject: [PATCH] Stream audio from microphone for speech streaming

---
 speech/grpc/pom.xml                          |   6 +
 .../speech/StreamingRecognizeClient.java     | 150 ++++++++++--------
 .../speech/StreamingRecognizeClientTest.java | 106 ++++++++++---
 3 files changed, 175 insertions(+), 87 deletions(-)

diff --git a/speech/grpc/pom.xml b/speech/grpc/pom.xml
index 71d4536f246..ca857bb323b 100644
--- a/speech/grpc/pom.xml
+++ b/speech/grpc/pom.xml
@@ -156,6 +156,12 @@ limitations under the License.
       <version>0.31</version>
       <scope>test</scope>
     </dependency>
+    <dependency>
+      <groupId>org.mockito</groupId>
+      <artifactId>mockito-all</artifactId>
+      <version>1.10.19</version>
+      <scope>test</scope>
+    </dependency>
     <dependency>
       <groupId>io.grpc</groupId>
       <artifactId>grpc-auth</artifactId>
diff --git a/speech/grpc/src/main/java/com/examples/cloud/speech/StreamingRecognizeClient.java b/speech/grpc/src/main/java/com/examples/cloud/speech/StreamingRecognizeClient.java
index 03d029ec1fe..e7f17e1d29a 100644
--- a/speech/grpc/src/main/java/com/examples/cloud/speech/StreamingRecognizeClient.java
+++ b/speech/grpc/src/main/java/com/examples/cloud/speech/StreamingRecognizeClient.java
@@ -23,10 +23,10 @@
 import com.google.cloud.speech.v1beta1.RecognitionConfig.AudioEncoding;
 import com.google.cloud.speech.v1beta1.SpeechGrpc;
 import com.google.cloud.speech.v1beta1.StreamingRecognitionConfig;
+import com.google.cloud.speech.v1beta1.StreamingRecognitionResult;
 import com.google.cloud.speech.v1beta1.StreamingRecognizeRequest;
 import com.google.cloud.speech.v1beta1.StreamingRecognizeResponse;
 import com.google.protobuf.ByteString;
-import com.google.protobuf.TextFormat;
 
 import io.grpc.ManagedChannel;
 import io.grpc.ManagedChannelBuilder;
@@ -44,14 +44,17 @@
 import org.apache.log4j.Logger;
 import org.apache.log4j.SimpleLayout;
 
-import java.io.File;
-import java.io.FileInputStream;
 import java.io.IOException;
 import java.util.Arrays;
 import java.util.List;
 import java.util.concurrent.CountDownLatch;
 import java.util.concurrent.Executors;
 import java.util.concurrent.TimeUnit;
+import javax.sound.sampled.AudioFormat;
+import javax.sound.sampled.AudioSystem;
+import javax.sound.sampled.DataLine;
+import javax.sound.sampled.LineUnavailableException;
+import javax.sound.sampled.TargetDataLine;
 
 
 /**
@@ -59,37 +62,36 @@
  */
 public class StreamingRecognizeClient {
 
-  private final String file;
-  private final int samplingRate;
-
   private static final Logger logger =
       Logger.getLogger(StreamingRecognizeClient.class.getName());
 
   private final ManagedChannel channel;
-
   private final SpeechGrpc.SpeechStub speechClient;
-
-  private static final int BYTES_PER_BUFFER = 3200; //buffer size in bytes
-  private static final int BYTES_PER_SAMPLE = 2; //bytes per sample for LINEAR16
-
   private static final List<String> OAUTH2_SCOPES =
       Arrays.asList("https://www.googleapis.com/auth/cloud-platform");
 
+  static final int BYTES_PER_SAMPLE = 2; // bytes per sample for LINEAR16
+
+  private final int samplingRate;
+  final int bytesPerBuffer; // buffer size in bytes
+
+  // Used for testing
+  protected TargetDataLine mockDataLine = null;
+
   /**
    * Construct client connecting to Cloud Speech server at {@code host:port}.
    */
-  public StreamingRecognizeClient(ManagedChannel channel, String file, int samplingRate)
+  public StreamingRecognizeClient(ManagedChannel channel, int samplingRate)
       throws IOException {
-    this.file = file;
     this.samplingRate = samplingRate;
     this.channel = channel;
+    this.bytesPerBuffer = samplingRate * BYTES_PER_SAMPLE / 10; // 100 ms
 
     speechClient = SpeechGrpc.newStub(channel);
 
     // Send log4j logs to Console
     // If you are going to run this on GCE, you might wish to integrate with
-    // google-cloud-java logging.  See:
+    // google-cloud-java logging. See:
     // https://github.com/GoogleCloudPlatform/google-cloud-java/blob/master/README.md#stackdriver-logging-alpha
-
     ConsoleAppender appender = new ConsoleAppender(new SimpleLayout(), SYSTEM_OUT);
     logger.addAppender(appender);
   }
@@ -109,19 +111,73 @@ static ManagedChannel createChannel(String host, int port) throws IOException {
 
     return channel;
   }
 
+  /**
+   * Return a Line to the audio input device.
+   */
+  private TargetDataLine getAudioInputLine() {
+    // For testing
+    if (null != mockDataLine) {
+      return mockDataLine;
+    }
+
+    AudioFormat format = new AudioFormat(samplingRate, BYTES_PER_SAMPLE * 8, 1, true, false);
+    DataLine.Info info = new DataLine.Info(TargetDataLine.class, format);
+    if (!AudioSystem.isLineSupported(info)) {
+      throw new RuntimeException(String.format(
+          "Device doesn't support LINEAR16 mono raw audio format at %dHz", samplingRate));
+    }
+    try {
+      TargetDataLine line = (TargetDataLine) AudioSystem.getLine(info);
+      // Make sure the line buffer doesn't overflow while we're filling this thread's buffer.
+      line.open(format, bytesPerBuffer * 5);
+      return line;
+    } catch (LineUnavailableException e) {
+      throw new RuntimeException(e);
+    }
+  }
+
   /** Send streaming recognize requests to server. */
   public void recognize() throws InterruptedException, IOException {
     final CountDownLatch finishLatch = new CountDownLatch(1);
     StreamObserver<StreamingRecognizeResponse> responseObserver =
         new StreamObserver<StreamingRecognizeResponse>() {
+          private int sentenceLength = 1;
+          /**
+           * Prints the transcription results. Interim results are overwritten by subsequent
+           * results, until a final one is returned, at which point we start a new line.
+           *
+           * Flags the program to exit when it hears "exit".
+           */
           @Override
           public void onNext(StreamingRecognizeResponse response) {
-            logger.info("Received response: " + TextFormat.printToString(response));
+            List<StreamingRecognitionResult> results = response.getResultsList();
+            if (results.size() < 1) {
+              return;
+            }
+
+            StreamingRecognitionResult result = results.get(0);
+            String transcript = result.getAlternatives(0).getTranscript();
+
+            // Print interim results with a carriage return, so the next transcription
+            // overwrites it. The final result is followed by a newline instead.
+            String format = "%-" + this.sentenceLength + 's';
+            if (result.getIsFinal()) {
+              format += '\n';
+              this.sentenceLength = 1;
+
+              if (transcript.toLowerCase().indexOf("exit") >= 0) {
+                finishLatch.countDown();
+              }
+            } else {
+              format += '\r';
+              this.sentenceLength = transcript.length();
+            }
+            System.out.print(String.format(format, transcript));
           }
 
           @Override
           public void onError(Throwable error) {
-            logger.log(Level.WARN, "recognize failed: {0}", error);
+            logger.log(Level.ERROR, "recognize failed: {0}", error);
             finishLatch.countDown();
           }
 
@@ -146,33 +202,28 @@ public void onCompleted() {
           StreamingRecognitionConfig.newBuilder()
               .setConfig(config)
               .setInterimResults(true)
-              .setSingleUtterance(true)
+              .setSingleUtterance(false)
               .build();
 
       StreamingRecognizeRequest initial =
          StreamingRecognizeRequest.newBuilder().setStreamingConfig(streamingConfig).build();
       requestObserver.onNext(initial);
 
-      // Open audio file. Read and send sequential buffers of audio as additional RecognizeRequests.
-      FileInputStream in = new FileInputStream(new File(file));
-      // For LINEAR16 at 16000 Hz sample rate, 3200 bytes corresponds to 100 milliseconds of audio.
-      byte[] buffer = new byte[BYTES_PER_BUFFER];
+      // Get a Line to the audio input device.
+      TargetDataLine in = getAudioInputLine();
+      byte[] buffer = new byte[bytesPerBuffer];
       int bytesRead;
-      int totalBytes = 0;
-      int samplesPerBuffer = BYTES_PER_BUFFER / BYTES_PER_SAMPLE;
-      int samplesPerMillis = samplingRate / 1000;
-      while ((bytesRead = in.read(buffer)) != -1) {
-        totalBytes += bytesRead;
+
+      in.start();
+      // Read and send sequential buffers of audio as additional RecognizeRequests.
+      while (finishLatch.getCount() > 0
+          && (bytesRead = in.read(buffer, 0, buffer.length)) != -1) {
         StreamingRecognizeRequest request =
             StreamingRecognizeRequest.newBuilder()
                 .setAudioContent(ByteString.copyFrom(buffer, 0, bytesRead))
                .build();
         requestObserver.onNext(request);
-        // To simulate real-time audio, sleep after sending each audio buffer.
-        Thread.sleep(samplesPerBuffer / samplesPerMillis);
       }
-      logger.info("Sent " + totalBytes + " bytes from audio file: " + file);
     } catch (RuntimeException e) {
       // Cancel RPC.
       requestObserver.onError(e);
@@ -187,21 +238,13 @@ public void onCompleted() {
 
   public static void main(String[] args) throws Exception {
 
-    String audioFile = "";
-    String host = "speech.googleapis.com";
-    Integer port = 443;
-    Integer sampling = 16000;
+    String host = null;
+    Integer port = null;
+    Integer sampling = null;
 
     CommandLineParser parser = new DefaultParser();
 
     Options options = new Options();
-    options.addOption(
-        Option.builder()
-            .longOpt("file")
-            .desc("path to audio file")
-            .hasArg()
-            .argName("FILE_PATH")
-            .build());
     options.addOption(
         Option.builder()
             .longOpt("host")
@@ -226,31 +269,14 @@ public static void main(String[] args) throws Exception {
 
     try {
       CommandLine line = parser.parse(options, args);
-      if (line.hasOption("file")) {
-        audioFile = line.getOptionValue("file");
-      } else {
-        System.err.println("An Audio file must be specified (e.g. /foo/baz.raw).");
-        System.exit(1);
-      }
-
-      if (line.hasOption("host")) {
-        host = line.getOptionValue("host");
-      } else {
-        System.err.println("An API enpoint must be specified (typically speech.googleapis.com).");
-        System.exit(1);
-      }
+
-      if (line.hasOption("port")) {
-        port = Integer.parseInt(line.getOptionValue("port"));
-      } else {
-        System.err.println("An SSL port must be specified (typically 443).");
-        System.exit(1);
-      }
+      host = line.getOptionValue("host", "speech.googleapis.com");
+      port = Integer.parseInt(line.getOptionValue("port", "443"));
 
       if (line.hasOption("sampling")) {
         sampling = Integer.parseInt(line.getOptionValue("sampling"));
       } else {
-        System.err.println("An Audio sampling rate must be specified.");
+        System.err.println("An audio sampling rate (--sampling) must be specified (e.g. 16000).");
         System.exit(1);
       }
     } catch (ParseException exp) {
@@ -259,7 +285,7 @@ public static void main(String[] args) throws Exception {
     }
 
     ManagedChannel channel = createChannel(host, port);
-    StreamingRecognizeClient client = new StreamingRecognizeClient(channel, audioFile, sampling);
+    StreamingRecognizeClient client = new StreamingRecognizeClient(channel, sampling);
     try {
       client.recognize();
     } finally {
diff --git a/speech/grpc/src/test/java/com/examples/cloud/speech/StreamingRecognizeClientTest.java b/speech/grpc/src/test/java/com/examples/cloud/speech/StreamingRecognizeClientTest.java
index 773af36c707..7ed15a0fe9e 100644
--- a/speech/grpc/src/test/java/com/examples/cloud/speech/StreamingRecognizeClientTest.java
+++ b/speech/grpc/src/test/java/com/examples/cloud/speech/StreamingRecognizeClientTest.java
@@ -17,24 +17,26 @@
 package com.examples.cloud.speech;
 
 import static com.google.common.truth.Truth.assertThat;
+import static org.mockito.Mockito.any;
+import static org.mockito.Mockito.anyInt;
+import static org.mockito.Mockito.when;
 
 import io.grpc.ManagedChannel;
-import org.apache.log4j.Logger;
-import org.apache.log4j.SimpleLayout;
-import org.apache.log4j.WriterAppender;
 import org.junit.After;
 import org.junit.Before;
 import org.junit.Test;
 import org.junit.runner.RunWith;
 import org.junit.runners.JUnit4;
+import org.mockito.Mock;
+import org.mockito.MockitoAnnotations;
+import org.mockito.invocation.InvocationOnMock;
+import org.mockito.stubbing.Answer;
 
-import java.io.File;
+import java.io.ByteArrayOutputStream;
+import java.io.FileInputStream;
 import java.io.IOException;
-import java.io.StringWriter;
-import java.io.Writer;
-import java.net.URI;
-import java.nio.file.Path;
-import java.nio.file.Paths;
+import java.io.PrintStream;
+import javax.sound.sampled.TargetDataLine;
 
 
 /**
@@ -42,46 +44,100 @@
  */
 @RunWith(JUnit4.class)
 public class StreamingRecognizeClientTest {
-  private Writer writer;
-  private WriterAppender appender;
+  private final ByteArrayOutputStream stdout = new ByteArrayOutputStream();
+  private static final PrintStream REAL_OUT = System.out;
+
+  @Mock private TargetDataLine mockDataLine;
 
   @Before
   public void setUp() {
-    writer = new StringWriter();
-    appender = new WriterAppender(new SimpleLayout(), writer);
-    Logger.getRootLogger().addAppender(appender);
+    MockitoAnnotations.initMocks(this);
+    System.setOut(new PrintStream(stdout));
   }
 
   @After
   public void tearDown() {
-    Logger.getRootLogger().removeAppender(appender);
+    System.setOut(REAL_OUT);
   }
 
   @Test
   public void test16KHzAudio() throws InterruptedException, IOException {
-    URI uri = new File("resources/audio.raw").toURI();
-    Path path = Paths.get(uri);
-
     String host = "speech.googleapis.com";
     int port = 443;
     ManagedChannel channel = StreamingRecognizeClient.createChannel(host, port);
-    StreamingRecognizeClient client = new StreamingRecognizeClient(channel, path.toString(), 16000);
+
+    final FileInputStream in = new FileInputStream("resources/audio.raw");
+
+    final int samplingRate = 16000;
+    final StreamingRecognizeClient client = new StreamingRecognizeClient(channel, samplingRate);
+
+    // When audio data is requested from the mock, get it from the file
+    when(mockDataLine.read(any(byte[].class), anyInt(), anyInt())).thenAnswer(new Answer() {
+      public Object answer(InvocationOnMock invocation) {
+        Object[] args = invocation.getArguments();
+        byte[] buffer = (byte[]) args[0];
+        int offset = (int) args[1];
+        int len = (int) args[2];
+        assertThat(buffer.length).isEqualTo(len);
+
+        try {
+          // Sleep, to simulate real time
+          int samplesPerBuffer = client.bytesPerBuffer / StreamingRecognizeClient.BYTES_PER_SAMPLE;
+          int samplesPerMillis = samplingRate / 1000;
+          Thread.sleep(samplesPerBuffer / samplesPerMillis);
+
+          // Provide the audio bytes from the file
+          return in.read(buffer, offset, len);
+
+        } catch (Exception e) {
+          throw new RuntimeException(e);
+        }
+      }
+    });
+    client.mockDataLine = mockDataLine;
 
     client.recognize();
-    assertThat(writer.toString()).contains("transcript: \"how old is the Brooklyn Bridge\"");
+
+    assertThat(stdout.toString()).contains("how old is the Brooklyn Bridge");
   }
 
   @Test
   public void test32KHzAudio() throws InterruptedException, IOException {
-    URI uri = new File("resources/audio32KHz.raw").toURI();
-    Path path = Paths.get(uri);
-
     String host = "speech.googleapis.com";
     int port = 443;
     ManagedChannel channel = StreamingRecognizeClient.createChannel(host, port);
-    StreamingRecognizeClient client = new StreamingRecognizeClient(channel, path.toString(), 32000);
+
+    final FileInputStream in = new FileInputStream("resources/audio32KHz.raw");
+
+    final int samplingRate = 32000;
+    final StreamingRecognizeClient client = new StreamingRecognizeClient(channel, samplingRate);
+
+    // When audio data is requested from the mock, get it from the file
+    when(mockDataLine.read(any(byte[].class), anyInt(), anyInt())).thenAnswer(new Answer() {
+      public Object answer(InvocationOnMock invocation) {
+        Object[] args = invocation.getArguments();
+        byte[] buffer = (byte[]) args[0];
+        int offset = (int) args[1];
+        int len = (int) args[2];
+        assertThat(buffer.length).isEqualTo(len);
+
+        try {
+          // Sleep, to simulate real time
+          int samplesPerBuffer = client.bytesPerBuffer / StreamingRecognizeClient.BYTES_PER_SAMPLE;
+          int samplesPerMillis = samplingRate / 1000;
+          Thread.sleep(samplesPerBuffer / samplesPerMillis);
+
+          // Provide the audio bytes from the file
+          return in.read(buffer, offset, len);
+        } catch (Exception e) {
+          throw new RuntimeException(e);
+        }
+      }
+    });
+    client.mockDataLine = mockDataLine;
 
     client.recognize();
-    assertThat(writer.toString()).contains("transcript: \"how old is the Brooklyn Bridge\"");
+
+    assertThat(stdout.toString()).contains("how old is the Brooklyn Bridge");
   }
 }
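
--
Note on the tests: both stub TargetDataLine.read() with the same inline Answer,
which paces reads at real-time capture speed and serves bytes from a sample
file. If more sampling rates are added later, that stub could be factored into
a shared helper along these lines. This is a sketch only; the helper name
answerWithAudioFile is hypothetical and not part of this patch:

    private static Answer<Integer> answerWithAudioFile(
        final FileInputStream in, final StreamingRecognizeClient client, final int samplingRate) {
      return new Answer<Integer>() {
        @Override
        public Integer answer(InvocationOnMock invocation) throws Throwable {
          Object[] args = invocation.getArguments();
          byte[] buffer = (byte[]) args[0];
          // Sleep for one buffer's duration, to simulate a real-time microphone.
          int samplesPerBuffer = client.bytesPerBuffer / StreamingRecognizeClient.BYTES_PER_SAMPLE;
          int samplesPerMillis = samplingRate / 1000;
          Thread.sleep(samplesPerBuffer / samplesPerMillis);
          // Serve the next chunk of audio from the file; returns -1 at EOF, ending the stream.
          return in.read(buffer, (int) args[1], (int) args[2]);
        }
      };
    }

Each test's stubbing would then reduce to:

    when(mockDataLine.read(any(byte[].class), anyInt(), anyInt()))
        .thenAnswer(answerWithAudioFile(in, client, samplingRate));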