From b6cf04941f29ecb4c44ab10ba1c4df3e8bccbaec Mon Sep 17 00:00:00 2001 From: Thomas Date: Sun, 25 Feb 2024 18:05:38 +0900 Subject: [PATCH] THRIFT-5762 Expose service result objects in Java Some libraries want to bypass the TServer class and handle the full service startup manually. For example when building a service that hosts multiple thrift services where the IFace type is unknown when handling a request. For example when you host multiple services on top of netty and through an HTTP path you want to route to the correct thrift service. In this situation you treat can treat an IFace as an Object and use the `getProcessMapView()` method to parse a byte array into a thrift message and pass let the `AsyncProcessFunction` handle the invocation. To return a correct thrift response it's necessary to write the `{service_name}_result` that contains the response args. While it is possible to get an incoming args object from the (Async)ProcessFunction its unfortunately not possible to get a result object without using reflection. This PR extends the (Async)ProcessFunction by adding a `getEmptyResultInstance` method that returns a new generic `A` (answer) that matches the `{service_name}_result` object. This allows thrift users to write the following processing code: ```java void handleRequest( TProtocol in, TProtocol out, TBaseAsyncProcessor processor, I asyncIface ) throws TException { final Map, TBase, TBase>> processMap = (Map) processor.getProcessMapView(); final var msg = in.readMessageBegin(); final var fn = processMap.get(msg.name); final var args = fn.getEmptyArgsInstance(); args.read(in); in.readMessageEnd(); if (fn.isOneway()) { return; } fn.start(asyncIface, args, new AsyncMethodCallback<>() { @Override public void onComplete(TBase o) { try { out.writeMessageBegin(new TMessage(fn.getMethodName(), TMessageType.REPLY, msg.getSeqid())); final var response_result = fn.getEmptyResultInstance(); final var success_field = response_result.fieldForId(SUCCESS_ID); ((TBase) response_result).setFieldValue(success_field, o); response_result.write(out); out.writeMessageEnd(); out.getTransport().flush(); } catch (TException e) { throw new RuntimeException(e); } } @Override public void onError(Exception e) { try { out.writeMessageBegin(new TMessage(fn.getMethodName(), TMessageType.EXCEPTION, msg.getSeqid())); ((TApplicationException) e).write(out); out.writeMessageEnd(); out.getTransport().flush(); } catch (TException ex) { throw new RuntimeException(ex); } } }); } ``` The above example code doesn't need any reference to the original types and can dynamically create the correct objects to return a correct response. --- .../src/thrift/generate/t_java_generator.cc | 51 +++++-- .../apache/thrift/AsyncProcessFunction.java | 6 +- .../org/apache/thrift/ProcessFunction.java | 141 +++++++++--------- .../apache/thrift/TBaseAsyncProcessor.java | 6 +- .../org/apache/thrift/TBaseProcessor.java | 6 +- .../thrift/server/TSaslNonblockingServer.java | 2 +- 6 files changed, 120 insertions(+), 92 deletions(-) diff --git a/compiler/cpp/src/thrift/generate/t_java_generator.cc b/compiler/cpp/src/thrift/generate/t_java_generator.cc index d7e0b65939c..1985a3d8494 100644 --- a/compiler/cpp/src/thrift/generate/t_java_generator.cc +++ b/compiler/cpp/src/thrift/generate/t_java_generator.cc @@ -3635,22 +3635,23 @@ void t_java_generator::generate_service_server(t_service* tservice) { indent(f_service_) << "public Processor(I iface) {" << endl; indent(f_service_) << " super(iface, getProcessMap(new java.util.HashMap>()));" + "org.apache.thrift.TBase, ? extends org.apache.thrift.TBase>>()));" << endl; indent(f_service_) << "}" << endl << endl; indent(f_service_) << "protected Processor(I iface, java.util.Map> " - "processMap) {" + "org.apache.thrift.ProcessFunction> processMap) {" << endl; indent(f_service_) << " super(iface, getProcessMap(processMap));" << endl; indent(f_service_) << "}" << endl << endl; - indent(f_service_) << "private static java.util.Map> " + indent(f_service_) << "private static java.util.Map> " "getProcessMap(java.util.Map> processMap) {" + " org.apache.thrift.TBase, ? extends org.apache.thrift.TBase>> processMap) {" << endl; indent_up(); for (f_iter = functions.begin(); f_iter != functions.end(); ++f_iter) { @@ -3702,13 +3703,13 @@ void t_java_generator::generate_service_async_server(t_service* tservice) { indent(f_service_) << "public AsyncProcessor(I iface) {" << endl; indent(f_service_) << " super(iface, getProcessMap(new java.util.HashMap>()));" + "org.apache.thrift.TBase, ?, ? extends org.apache.thrift.TBase>>()));" << endl; indent(f_service_) << "}" << endl << endl; indent(f_service_) << "protected AsyncProcessor(I iface, java.util.Map> processMap) {" + "org.apache.thrift.TBase, ?, ? extends org.apache.thrift.TBase>> processMap) {" << endl; indent(f_service_) << " super(iface, getProcessMap(processMap));" << endl; indent(f_service_) << "}" << endl << endl; @@ -3716,9 +3717,9 @@ void t_java_generator::generate_service_async_server(t_service* tservice) { indent(f_service_) << "private static java.util.Map> getProcessMap(java.util.Map> getProcessMap(java.util.Map> processMap) {" + "org.apache.thrift.TBase, ?, ? extends org.apache.thrift.TBase>> processMap) {" << endl; indent_up(); for (f_iter = functions.begin(); f_iter != functions.end(); ++f_iter) { @@ -3783,13 +3784,23 @@ void t_java_generator::generate_process_async_function(t_service* tservice, t_fu // Open class indent(f_service_) << "public static class " << make_valid_java_identifier(tfunction->get_name()) << " extends org.apache.thrift.AsyncProcessFunction {" << endl; + << argsname << ", " << resulttype << ", " << resultname << "> {" << endl; indent_up(); indent(f_service_) << "public " << make_valid_java_identifier(tfunction->get_name()) << "() {" << endl; indent(f_service_) << " super(\"" << tfunction->get_name() << "\");" << endl; indent(f_service_) << "}" << endl << endl; + indent(f_service_) << java_override_annotation() << endl; + indent(f_service_) << "public " << resultname << " getEmptyResultInstance() {" << endl; + if (tfunction->is_oneway()) { + indent(f_service_) << " return null;" << endl; + } + else { + indent(f_service_) << " return new " << resultname << "();" << endl; + } + indent(f_service_) << "}" << endl << endl; + indent(f_service_) << java_override_annotation() << endl; indent(f_service_) << "public " << argsname << " getEmptyArgsInstance() {" << endl; indent(f_service_) << " return new " << argsname << "();" << endl; @@ -3931,7 +3942,7 @@ void t_java_generator::generate_process_async_function(t_service* tservice, t_fu indent(f_service_) << "}" << endl << endl; indent(f_service_) << java_override_annotation() << endl; - indent(f_service_) << "protected boolean isOneway() {" << endl; + indent(f_service_) << "public boolean isOneway() {" << endl; indent(f_service_) << " return " << ((tfunction->is_oneway()) ? "true" : "false") << ";" << endl; indent(f_service_) << "}" << endl << endl; @@ -3989,7 +4000,7 @@ void t_java_generator::generate_process_function(t_service* tservice, t_function // Open class indent(f_service_) << "public static class " << make_valid_java_identifier(tfunction->get_name()) << " extends org.apache.thrift.ProcessFunction {" << endl; + << argsname << ", " << resultname << "> {" << endl; indent_up(); indent(f_service_) << "public " << make_valid_java_identifier(tfunction->get_name()) << "() {" << endl; @@ -4002,7 +4013,7 @@ void t_java_generator::generate_process_function(t_service* tservice, t_function indent(f_service_) << "}" << endl << endl; indent(f_service_) << java_override_annotation() << endl; - indent(f_service_) << "protected boolean isOneway() {" << endl; + indent(f_service_) << "public boolean isOneway() {" << endl; indent(f_service_) << " return " << ((tfunction->is_oneway()) ? "true" : "false") << ";" << endl; indent(f_service_) << "}" << endl << endl; @@ -4012,12 +4023,22 @@ void t_java_generator::generate_process_function(t_service* tservice, t_function << endl; indent(f_service_) << "}" << endl << endl; + indent(f_service_) << java_override_annotation() << endl; + indent(f_service_) << "public " << resultname << " getEmptyResultInstance() {" << endl; + if (tfunction->is_oneway()) { + indent(f_service_) << " return null;" << endl; + } + else { + indent(f_service_) << " return new " << resultname << "();" << endl; + } + indent(f_service_) << "}" << endl << endl; + indent(f_service_) << java_override_annotation() << endl; indent(f_service_) << "public " << resultname << " getResult(I iface, " << argsname << " args) throws org.apache.thrift.TException {" << endl; indent_up(); if (!tfunction->is_oneway()) { - indent(f_service_) << resultname << " result = new " << resultname << "();" << endl; + indent(f_service_) << resultname << " result = getEmptyResultInstance();" << endl; } t_struct* xs = tfunction->get_xceptions(); diff --git a/lib/java/src/main/java/org/apache/thrift/AsyncProcessFunction.java b/lib/java/src/main/java/org/apache/thrift/AsyncProcessFunction.java index c7c4be3036d..4e65ae66e17 100644 --- a/lib/java/src/main/java/org/apache/thrift/AsyncProcessFunction.java +++ b/lib/java/src/main/java/org/apache/thrift/AsyncProcessFunction.java @@ -23,20 +23,22 @@ import org.apache.thrift.protocol.TProtocol; import org.apache.thrift.server.AbstractNonblockingServer; -public abstract class AsyncProcessFunction { +public abstract class AsyncProcessFunction { final String methodName; public AsyncProcessFunction(String methodName) { this.methodName = methodName; } - protected abstract boolean isOneway(); + public abstract boolean isOneway(); public abstract void start(I iface, T args, AsyncMethodCallback resultHandler) throws TException; public abstract T getEmptyArgsInstance(); + public abstract A getEmptyResultInstance(); + public abstract AsyncMethodCallback getResultHandler( final AbstractNonblockingServer.AsyncFrameBuffer fb, int seqid); diff --git a/lib/java/src/main/java/org/apache/thrift/ProcessFunction.java b/lib/java/src/main/java/org/apache/thrift/ProcessFunction.java index 7399342a217..8552863aaef 100644 --- a/lib/java/src/main/java/org/apache/thrift/ProcessFunction.java +++ b/lib/java/src/main/java/org/apache/thrift/ProcessFunction.java @@ -8,86 +8,91 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; -public abstract class ProcessFunction { - private final String methodName; +public abstract class ProcessFunction { + private final String methodName; - private static final Logger LOGGER = LoggerFactory.getLogger(ProcessFunction.class.getName()); + private static final Logger LOGGER = LoggerFactory.getLogger(ProcessFunction.class.getName()); - public ProcessFunction(String methodName) { - this.methodName = methodName; - } - - public final void process(int seqid, TProtocol iprot, TProtocol oprot, I iface) - throws TException { - T args = getEmptyArgsInstance(); - try { - args.read(iprot); - } catch (TProtocolException e) { - iprot.readMessageEnd(); - TApplicationException x = - new TApplicationException(TApplicationException.PROTOCOL_ERROR, e.getMessage()); - oprot.writeMessageBegin(new TMessage(getMethodName(), TMessageType.EXCEPTION, seqid)); - x.write(oprot); - oprot.writeMessageEnd(); - oprot.getTransport().flush(); - return; + public ProcessFunction(String methodName) { + this.methodName = methodName; } - iprot.readMessageEnd(); - TSerializable result = null; - byte msgType = TMessageType.REPLY; - try { - result = getResult(iface, args); - } catch (TTransportException ex) { - LOGGER.error("Transport error while processing " + getMethodName(), ex); - throw ex; - } catch (TApplicationException ex) { - LOGGER.error("Internal application error processing " + getMethodName(), ex); - result = ex; - msgType = TMessageType.EXCEPTION; - } catch (Exception ex) { - LOGGER.error("Internal error processing " + getMethodName(), ex); - if (rethrowUnhandledExceptions()) throw new RuntimeException(ex.getMessage(), ex); - if (!isOneway()) { - result = - new TApplicationException( - TApplicationException.INTERNAL_ERROR, - "Internal error processing " + getMethodName()); - msgType = TMessageType.EXCEPTION; - } + public final void process(int seqid, TProtocol iprot, TProtocol oprot, I iface) + throws TException { + T args = getEmptyArgsInstance(); + try { + args.read(iprot); + } catch (TProtocolException e) { + iprot.readMessageEnd(); + TApplicationException x = + new TApplicationException(TApplicationException.PROTOCOL_ERROR, e.getMessage()); + oprot.writeMessageBegin(new TMessage(getMethodName(), TMessageType.EXCEPTION, seqid)); + x.write(oprot); + oprot.writeMessageEnd(); + oprot.getTransport().flush(); + return; + } + iprot.readMessageEnd(); + TSerializable result = null; + byte msgType = TMessageType.REPLY; + + try { + result = getResult(iface, args); + } catch (TTransportException ex) { + LOGGER.error("Transport error while processing " + getMethodName(), ex); + throw ex; + } catch (TApplicationException ex) { + LOGGER.error("Internal application error processing " + getMethodName(), ex); + result = ex; + msgType = TMessageType.EXCEPTION; + } catch (Exception ex) { + LOGGER.error("Internal error processing " + getMethodName(), ex); + if (rethrowUnhandledExceptions()) throw new RuntimeException(ex.getMessage(), ex); + if (!isOneway()) { + result = + new TApplicationException( + TApplicationException.INTERNAL_ERROR, + "Internal error processing " + getMethodName()); + msgType = TMessageType.EXCEPTION; + } + } + + if (!isOneway()) { + oprot.writeMessageBegin(new TMessage(getMethodName(), msgType, seqid)); + result.write(oprot); + oprot.writeMessageEnd(); + oprot.getTransport().flush(); + } } - if (!isOneway()) { - oprot.writeMessageBegin(new TMessage(getMethodName(), msgType, seqid)); - result.write(oprot); - oprot.writeMessageEnd(); - oprot.getTransport().flush(); + private void handleException(int seqid, TProtocol oprot) throws TException { + if (!isOneway()) { + TApplicationException x = + new TApplicationException( + TApplicationException.INTERNAL_ERROR, "Internal error processing " + getMethodName()); + oprot.writeMessageBegin(new TMessage(getMethodName(), TMessageType.EXCEPTION, seqid)); + x.write(oprot); + oprot.writeMessageEnd(); + oprot.getTransport().flush(); + } } - } - private void handleException(int seqid, TProtocol oprot) throws TException { - if (!isOneway()) { - TApplicationException x = - new TApplicationException( - TApplicationException.INTERNAL_ERROR, "Internal error processing " + getMethodName()); - oprot.writeMessageBegin(new TMessage(getMethodName(), TMessageType.EXCEPTION, seqid)); - x.write(oprot); - oprot.writeMessageEnd(); - oprot.getTransport().flush(); + protected boolean rethrowUnhandledExceptions() { + return false; } - } - protected boolean rethrowUnhandledExceptions() { - return false; - } + public abstract boolean isOneway(); - protected abstract boolean isOneway(); + public abstract TBase getResult(I iface, T args) throws TException; - public abstract TBase getResult(I iface, T args) throws TException; + public abstract T getEmptyArgsInstance(); - public abstract T getEmptyArgsInstance(); + /** + * Returns null when this is a oneWay function. + */ + public abstract A getEmptyResultInstance(); - public String getMethodName() { - return methodName; - } + public String getMethodName() { + return methodName; + } } diff --git a/lib/java/src/main/java/org/apache/thrift/TBaseAsyncProcessor.java b/lib/java/src/main/java/org/apache/thrift/TBaseAsyncProcessor.java index 266f0c0ceec..0a583c05a6a 100644 --- a/lib/java/src/main/java/org/apache/thrift/TBaseAsyncProcessor.java +++ b/lib/java/src/main/java/org/apache/thrift/TBaseAsyncProcessor.java @@ -30,15 +30,15 @@ public class TBaseAsyncProcessor implements TAsyncProcessor, TProcessor { protected final Logger LOGGER = LoggerFactory.getLogger(getClass().getName()); final I iface; - final Map> processMap; + final Map> processMap; public TBaseAsyncProcessor( - I iface, Map> processMap) { + I iface, Map> processMap) { this.iface = iface; this.processMap = processMap; } - public Map> getProcessMapView() { + public Map> getProcessMapView() { return Collections.unmodifiableMap(processMap); } diff --git a/lib/java/src/main/java/org/apache/thrift/TBaseProcessor.java b/lib/java/src/main/java/org/apache/thrift/TBaseProcessor.java index 05cd7b8ccda..ff1ccfcc9c0 100644 --- a/lib/java/src/main/java/org/apache/thrift/TBaseProcessor.java +++ b/lib/java/src/main/java/org/apache/thrift/TBaseProcessor.java @@ -10,15 +10,15 @@ public abstract class TBaseProcessor implements TProcessor { private final I iface; - private final Map> processMap; + private final Map> processMap; protected TBaseProcessor( - I iface, Map> processFunctionMap) { + I iface, Map> processFunctionMap) { this.iface = iface; this.processMap = processFunctionMap; } - public Map> getProcessMapView() { + public Map> getProcessMapView() { return Collections.unmodifiableMap(processMap); } diff --git a/lib/java/src/main/java/org/apache/thrift/server/TSaslNonblockingServer.java b/lib/java/src/main/java/org/apache/thrift/server/TSaslNonblockingServer.java index 6f22d8bb454..8c899d56cdb 100644 --- a/lib/java/src/main/java/org/apache/thrift/server/TSaslNonblockingServer.java +++ b/lib/java/src/main/java/org/apache/thrift/server/TSaslNonblockingServer.java @@ -255,7 +255,7 @@ private void handleIO() { } else if (selected.isWritable()) { saslHandler.handleWrite(); } else { - LOGGER.error("Invalid intrest op " + selected.interestOps()); + LOGGER.error("Invalid interest op " + selected.interestOps()); closeChannel(selected); continue; }