diff --git a/CMakeLists.txt b/CMakeLists.txt index 2136b3f..d0635dd 100755 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -40,6 +40,9 @@ add_subdirectory(MLModelRunner) add_subdirectory(SerDes) add_subdirectory(test) +add_custom_target(copy) +add_custom_command(TARGET copy PRE_BUILD COMMAND ${CMAKE_COMMAND} -E copy CompilerInterface ${CMAKE_CURRENT_BINARY_DIR}/MLModelRunner/CompilerInterface) + if(LLVM_MLBRIDGE) include(AddLLVM) include(HandleLLVMOptions) @@ -51,6 +54,9 @@ if(LLVM_MLBRIDGE) ADDITIONAL_HEADER_DIRS ${CMAKE_CURRENT_SOURCE_DIR}/include + DEPENDS + copy + LINK_LIBS ModelRunnerLib $ @@ -89,4 +95,4 @@ else() endif(LLVM_MLBRIDGE) install(DIRECTORY include/ DESTINATION include) -install(DIRECTORY CompilerInterface DESTINATION include/python/MLCompilerBridge) +install(DIRECTORY CompilerInterface DESTINATION MLModelRunner/CompilerInterface) diff --git a/MLModelRunner/gRPCModelRunner/CMakeLists.txt b/MLModelRunner/gRPCModelRunner/CMakeLists.txt index d9fb94a..6ce52d6 100755 --- a/MLModelRunner/gRPCModelRunner/CMakeLists.txt +++ b/MLModelRunner/gRPCModelRunner/CMakeLists.txt @@ -84,7 +84,7 @@ if(LLVM_MLBRIDGE) ${proto_python_srcs_list} ) else() - add_library(gRPCModelRunnerLib OBJECT + add_library(gRPCModelRunnerLib ${cc_files} ${proto_srcs_list} ${grpc_srcs_list} diff --git a/SerDes/protobufSerDes.cpp b/SerDes/protobufSerDes.cpp index 9990b7e..38c68a8 100644 --- a/SerDes/protobufSerDes.cpp +++ b/SerDes/protobufSerDes.cpp @@ -156,6 +156,13 @@ void *ProtobufSerDes::deserializeUntyped(void *data) { this->MessageLength = ref.size() * sizeof(int32_t); return ret->data(); } + if (field->type() == FieldDescriptor::Type::TYPE_INT64) { + auto &ref = reflection->GetRepeatedField(*Response, field); + std::vector *ret = + new std::vector(ref.begin(), ref.end()); + this->MessageLength = ref.size() * sizeof(int64_t); + return ret->data(); + } if (field->type() == FieldDescriptor::Type::TYPE_FLOAT) { auto ref = reflection->GetRepeatedField(*Response, field); std::vector *ret = new std::vector(ref.begin(), ref.end()); @@ -199,6 +206,12 @@ void *ProtobufSerDes::deserializeUntyped(void *data) { this->MessageLength = sizeof(int32_t); return ptr; } + if (field->type() == FieldDescriptor::Type::TYPE_INT64) { + int64_t value = reflection->GetInt64(*Response, field); + int64_t *ptr = new int64_t(value); + this->MessageLength = sizeof(int64_t); + return ptr; + } if (field->type() == FieldDescriptor::Type::TYPE_FLOAT) { float value = reflection->GetFloat(*Response, field); float *ptr = new float(value); diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index dccdc25..f4a83a3 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -6,5 +6,6 @@ file(GLOB MODEL_OBJECTS ${CMAKE_CURRENT_SOURCE_DIR}/tf_models/*.o) foreach(MODEL_OBJECT ${MODEL_OBJECTS}) target_link_libraries(MLBridgeCPPTest PRIVATE ${MODEL_OBJECT}) endforeach() -target_link_libraries(MLBridgeCPPTest PRIVATE ModelRunnerUtils) -target_include_directories(MLBridgeCPPTest PRIVATE ${CMAKE_BINARY_DIR}/include ${TENSORFLOW_AOT_PATH}/include) +target_link_libraries(MLBridgeCPPTest PRIVATE MLCompilerBridge ) +target_include_directories(MLBridgeCPPTest PRIVATE ${CMAKE_BINARY_DIR}/include ${TENSORFLOW_AOT_PATH}/include ${CMAKE_CURRENT_SOURCE_DIR}/include) +target_link_libraries(MLBridgeCPPTest PRIVATE tf_xla_runtime) diff --git a/test/MLBridgeTest.cpp b/test/MLBridgeTest.cpp index 794b1ee..f0847d2 100644 --- a/test/MLBridgeTest.cpp +++ b/test/MLBridgeTest.cpp @@ -6,6 +6,7 @@ // //===----------------------------------------------------------------------===// +#include "HelloMLBridge_Env.h" #include "MLModelRunner/MLModelRunner.h" #include "MLModelRunner/ONNXModelRunner/ONNXModelRunner.h" #include "MLModelRunner/PipeModelRunner.h" @@ -13,12 +14,8 @@ #include "MLModelRunner/Utils/DataTypes.h" #include "MLModelRunner/Utils/MLConfig.h" #include "MLModelRunner/gRPCModelRunner.h" -// #include "grpc/helloMLBridgeTest/helloMLBridgeTest.grpc.pb.h" -// #include "grpc/helloMLBridgeTest/helloMLBridgeTest.pb.h" -#include "grpcpp/impl/codegen/status.h" -#include "inference/HelloMLBridge_Env.h" +#include "ProtosInclude.h" #include "llvm/Support/CommandLine.h" -// #include "llvm/Support/raw_ostream.h" #include #include #include @@ -31,6 +28,20 @@ #define debug_out \ if (!silent) \ std::cout +using namespace grpc; + +#define gRPCModelRunnerInit(datatype) \ + increment_port(1); \ + MLBridgeTestgRPC_##datatype::Reply response; \ + MLBridgeTestgRPC_##datatype::Request request; \ + MLRunner = std::make_unique< \ + gRPCModelRunner>( \ + server_address, &request, &response, nullptr); \ + MLRunner->setRequest(&request); \ + MLRunner->setResponse(&response) static llvm::cl::opt cl_server_address("test-server-address", llvm::cl::Hidden, @@ -55,7 +66,6 @@ std::string basename; BaseSerDes::Kind SerDesType; std::string test_config; -std::string data_format; std::string pipe_name; std::string server_address; @@ -65,6 +75,7 @@ void testPrimitive(std::string label, T1 value, T2 expected) { std::pair p("request_" + label, value); MLRunner->populateFeatures(p); T2 out = MLRunner->evaluate(); + debug_out << " " << label << " reply: " << out << "\n"; if (std::abs(out - expected) > 10e-6) { std::cerr << "Error: Expected " << label << " reply: " << expected @@ -96,73 +107,117 @@ void testVector(std::string label, std::vector value, debug_out << "\n"; } -void runTests() { - if (data_format != "json") { - testPrimitive("int", 11, 12); - testPrimitive("long", 1234567890, 1234567891); - testPrimitive("float", 3.14, 4.14); - testPrimitive("double", 0.123456789123456789, - 1.123456789123456789); - testPrimitive("char", 'a', 'b'); - testPrimitive("bool", true, false); - testVector("vec_int", {11, 22, 33}, {12, 23, 34}); - testVector("vec_long", {123456780, 222, 333}, - {123456780, 123456781, 123456782}); - testVector("vec_float", {11.1, 22.2, 33.3}, - {1.11, 2.22, -3.33, 0}); - testVector("vec_double", - {-1.1111111111, -2.2222222222, -3.3333333333}, - {1.12345678912345670, -1.12345678912345671}); - } else if (data_format == "json") { - testPrimitive("int", 11, 12); - testPrimitive("long", 1234567890, 12345); - testPrimitive("float", 3.14, 4.14); - testPrimitive("double", 0.123456789123456789, - 1.123456789123456789); - testPrimitive("char", 'a', 'b'); - testPrimitive("bool", true, false); - testVector("vec_int", {11, 22, 33}, {12, 23, 34}); - testVector("vec_long", {123456780, 222, 333}, - {6780, 6781, 6782}); - testVector("vec_float", {11.1, 22.2, 33.3}, - {1.11, 2.22, -3.33, 0}); - testVector("vec_double", - {-1.1111111111, -2.2222222222, -3.3333333333}, - {1.12345678912345670, -1.12345678912345671}); +int testPipeBytes() { + if (pipe_name == "") { + std::cerr + << "Pipe name must be specified via --test-pipe-name=\n"; + exit(1); } + basename = "./" + pipe_name; + SerDesType = BaseSerDes::Kind::Bitstream; + MLRunner = std::make_unique( + basename + ".out", basename + ".in", SerDesType, nullptr); + testPrimitive("int", 11, 12); + testPrimitive("long", 1234567890, 1234567891); + testPrimitive("float", 3.14, 4.14); + testPrimitive("double", 0.123456789123456789, + 1.123456789123456789); + testPrimitive("char", 'a', 'b'); + testPrimitive("bool", true, false); + testVector("vec_int", {11, 22, 33}, {12, 23, 34}); + testVector("vec_long", {123456780, 222, 333}, + {123456780, 123456781, 123456782}); + testVector("vec_float", {11.1, 22.2, 33.3}, + {1.11, 2.22, -3.33, 0}); + testVector("vec_double", + {-1.1111111111, -2.2222222222, -3.3333333333}, + {1.12345678912345670, -1.12345678912345671}); + return 0; } -int testPipes() { +int testPipeJSON() { if (pipe_name == "") { std::cerr << "Pipe name must be specified via --test-pipe-name=\n"; exit(1); } - basename = "/tmp/" + pipe_name; - if (data_format == "json") - SerDesType = BaseSerDes::Kind::Json; - else if (data_format == "protobuf") - SerDesType = BaseSerDes::Kind::Protobuf; - else if (data_format == "bytes") - SerDesType = BaseSerDes::Kind::Bitstream; - else { - std::cout << "Invalid data format\n"; - exit(1); - } - + basename = "./" + pipe_name; + SerDesType = BaseSerDes::Kind::Json; MLRunner = std::make_unique( basename + ".out", basename + ".in", SerDesType, nullptr); - - runTests(); + testPrimitive("int", 11, 12); + testPrimitive("long", 1234567890, 12345); + testPrimitive("float", 3.14, 4.14); + testPrimitive("double", 0.123456789123456789, + 1.123456789123456789); + testPrimitive("char", 'a', 'b'); + testPrimitive("bool", true, false); + testVector("vec_int", {11, 22, 33}, {12, 23, 34}); + testVector("vec_long", {123456780, 222, 333}, + {6780, 6781, 6782}); + testVector("vec_float", {11.1, 22.2, 33.3}, + {1.11, 2.22, -3.33, 0}); + testVector("vec_double", + {-1.1111111111, -2.2222222222, -3.3333333333}, + {1.12345678912345670, -1.12345678912345671}); return 0; } +void increment_port(int delta) { + int split = server_address.find(":"); + int port = stoi(server_address.substr(split + 1)); + server_address = + server_address.substr(0, split) + ":" + to_string(port + delta); +} + int testGRPC() { if (server_address == "") { std::cerr << "Server Address must be specified via " - "--test-server-address=:\n"; + "--test-server-address=\":\"\n"; exit(1); } + { + gRPCModelRunnerInit(int); + testPrimitive("int", 11, 12); + } + { + gRPCModelRunnerInit(long); + testPrimitive("long", 1234567890, 1234567891); + } + { + gRPCModelRunnerInit(float); + testPrimitive("float", 3.14, 4.14); + } + { + gRPCModelRunnerInit(double); + testPrimitive("double", 0.123456789123456789, + 1.123456789123456789); + } + increment_port(1); + { + gRPCModelRunnerInit(bool); + testPrimitive("bool", true, false); + } + { + gRPCModelRunnerInit(vec_int); + testVector("vec_int", {11, 22, 33}, {12, 23, 34}); + } + { + gRPCModelRunnerInit(vec_long); + testVector("vec_long", {123456780, 222, 333}, + {123456780, 123456781, 123456782}); + } + { + gRPCModelRunnerInit(vec_float); + testVector("vec_float", {11.1, 22.2, 33.3}, + {1.11, 2.22, -3.33, 0}); + } + { + gRPCModelRunnerInit(vec_double); + testVector("vec_double", + {-1.1111111111, -2.2222222222, -3.3333333333}, + {1.12345678912345670, -1.12345678912345671}); + } return 0; } @@ -176,12 +231,10 @@ int main(int argc, char **argv) { if (test_config == "pipe-bytes") { pipe_name = cl_pipe_name.getValue(); - data_format = "bytes"; - testPipes(); + testPipeBytes(); } else if (test_config == "pipe-json") { pipe_name = cl_pipe_name.getValue(); - data_format = "json"; - testPipes(); + testPipeJSON(); } else if (test_config == "grpc") { server_address = cl_server_address.getValue(); testGRPC(); diff --git a/test/include/HelloMLBridge_Env.h b/test/include/HelloMLBridge_Env.h new file mode 100644 index 0000000..9ee3743 --- /dev/null +++ b/test/include/HelloMLBridge_Env.h @@ -0,0 +1,40 @@ +//===----------------------------------------------------------------------===// +// +// Part of the MLCompilerBridge Project, under the Apache 2.0 License. +// See the LICENSE file under home directory for license and copyright +// information. +// +//===----------------------------------------------------------------------===// + +#include "MLModelRunner/ONNXModelRunner/environment.h" +#include "MLModelRunner/ONNXModelRunner/utils.h" +#include "llvm/IR/Module.h" +#include "llvm/Support/raw_ostream.h" + +using namespace MLBridge; +class MLBridgeTestEnv : public Environment { + Observation CurrObs; + +public: + MLBridgeTestEnv() { setNextAgent("agent"); }; + Observation &reset() override; + Observation &step(Action) override; + +protected: + std::vector FeatureVector; +}; + +Observation &MLBridgeTestEnv::step(Action Action) { + CurrObs.clear(); + std::copy(FeatureVector.begin(), FeatureVector.end(), + std::back_inserter(CurrObs)); + llvm::outs() << "Action: " << Action << "\n"; + setDone(); + return CurrObs; +} + +Observation &MLBridgeTestEnv::reset() { + std::copy(FeatureVector.begin(), FeatureVector.end(), + std::back_inserter(CurrObs)); + return CurrObs; +} diff --git a/test/include/ProtosInclude.h b/test/include/ProtosInclude.h new file mode 100644 index 0000000..72b4241 --- /dev/null +++ b/test/include/ProtosInclude.h @@ -0,0 +1,20 @@ +#include "grpc/MLBridgeTest_bool/MLBridgeTest_bool.grpc.pb.h" +#include "grpc/MLBridgeTest_bool/MLBridgeTest_bool.pb.h" +#include "grpc/MLBridgeTest_char/MLBridgeTest_char.grpc.pb.h" +#include "grpc/MLBridgeTest_char/MLBridgeTest_char.pb.h" +#include "grpc/MLBridgeTest_double/MLBridgeTest_double.grpc.pb.h" +#include "grpc/MLBridgeTest_double/MLBridgeTest_double.pb.h" +#include "grpc/MLBridgeTest_float/MLBridgeTest_float.grpc.pb.h" +#include "grpc/MLBridgeTest_float/MLBridgeTest_float.pb.h" +#include "grpc/MLBridgeTest_int/MLBridgeTest_int.grpc.pb.h" +#include "grpc/MLBridgeTest_int/MLBridgeTest_int.pb.h" +#include "grpc/MLBridgeTest_long/MLBridgeTest_long.grpc.pb.h" +#include "grpc/MLBridgeTest_long/MLBridgeTest_long.pb.h" +#include "grpc/MLBridgeTest_vec_double/MLBridgeTest_vec_double.grpc.pb.h" +#include "grpc/MLBridgeTest_vec_double/MLBridgeTest_vec_double.pb.h" +#include "grpc/MLBridgeTest_vec_float/MLBridgeTest_vec_float.grpc.pb.h" +#include "grpc/MLBridgeTest_vec_float/MLBridgeTest_vec_float.pb.h" +#include "grpc/MLBridgeTest_vec_int/MLBridgeTest_vec_int.grpc.pb.h" +#include "grpc/MLBridgeTest_vec_int/MLBridgeTest_vec_int.pb.h" +#include "grpc/MLBridgeTest_vec_long/MLBridgeTest_vec_long.grpc.pb.h" +#include "grpc/MLBridgeTest_vec_long/MLBridgeTest_vec_long.pb.h" diff --git a/test/mlbridge-test.py b/test/mlbridge-test.py index 4cb8466..38c712d 100644 --- a/test/mlbridge-test.py +++ b/test/mlbridge-test.py @@ -5,22 +5,26 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception # # ------------------------------------------------------------------------------ - import argparse import numpy as np import ctypes import sys +import os import torch, torch.nn as nn - -sys.path.append("../CompilerInterface") +import subprocess +import time + +BUILD_DIR = "../build_release" +sys.path.extend( + [ + "../CompilerInterface", + f"{BUILD_DIR}/MLModelRunner/gRPCModelRunner/Python-Utilities", + ] +) from PipeCompilerInterface import PipeCompilerInterface from GrpcCompilerInterface import GrpcCompilerInterface -sys.path.append("./Python-Utilities") -import helloMLBridge_pb2, helloMLBridge_pb2_grpc, grpc -from concurrent import futures - FAIL = 1 SUCCESS = 0 @@ -53,10 +57,86 @@ "--server_port", type=int, help="Server Port", - default=5050, +) +parser.add_argument( + "--test_number", + type=int, + help="Datatype number for test", + default=0, ) args = parser.parse_args() +if args.test_number <= 1: + import MLBridgeTest_int_pb2, MLBridgeTest_int_pb2_grpc + + MLBridgeTest_pb2, MLBridgeTest_pb2_grpc = ( + MLBridgeTest_int_pb2, + MLBridgeTest_int_pb2_grpc, + ) +elif args.test_number == 2: + import MLBridgeTest_long_pb2, MLBridgeTest_long_pb2_grpc + + MLBridgeTest_pb2, MLBridgeTest_pb2_grpc = ( + MLBridgeTest_long_pb2, + MLBridgeTest_long_pb2_grpc, + ) +elif args.test_number == 3: + import MLBridgeTest_float_pb2, MLBridgeTest_float_pb2_grpc + + MLBridgeTest_pb2, MLBridgeTest_pb2_grpc = ( + MLBridgeTest_float_pb2, + MLBridgeTest_float_pb2_grpc, + ) +elif args.test_number == 4: + import MLBridgeTest_double_pb2, MLBridgeTest_double_pb2_grpc + + MLBridgeTest_pb2, MLBridgeTest_pb2_grpc = ( + MLBridgeTest_double_pb2, + MLBridgeTest_double_pb2_grpc, + ) +elif args.test_number == 5: + import MLBridgeTest_char_pb2, MLBridgeTest_char_pb2_grpc + + MLBridgeTest_pb2, MLBridgeTest_pb2_grpc = ( + MLBridgeTest_char_pb2, + MLBridgeTest_char_pb2_grpc, + ) +elif args.test_number == 6: + import MLBridgeTest_bool_pb2, MLBridgeTest_bool_pb2_grpc + + MLBridgeTest_pb2, MLBridgeTest_pb2_grpc = ( + MLBridgeTest_bool_pb2, + MLBridgeTest_bool_pb2_grpc, + ) +elif args.test_number == 7: + import MLBridgeTest_vec_int_pb2, MLBridgeTest_vec_int_pb2_grpc + + MLBridgeTest_pb2, MLBridgeTest_pb2_grpc = ( + MLBridgeTest_vec_int_pb2, + MLBridgeTest_vec_int_pb2_grpc, + ) +elif args.test_number == 8: + import MLBridgeTest_vec_long_pb2, MLBridgeTest_vec_long_pb2_grpc + + MLBridgeTest_pb2, MLBridgeTest_pb2_grpc = ( + MLBridgeTest_vec_long_pb2, + MLBridgeTest_vec_long_pb2_grpc, + ) +elif args.test_number == 9: + import MLBridgeTest_vec_float_pb2, MLBridgeTest_vec_float_pb2_grpc + + MLBridgeTest_pb2, MLBridgeTest_pb2_grpc = ( + MLBridgeTest_vec_float_pb2, + MLBridgeTest_vec_float_pb2_grpc, + ) +elif args.test_number == 10: + import MLBridgeTest_vec_double_pb2, MLBridgeTest_vec_double_pb2_grpc + + MLBridgeTest_pb2, MLBridgeTest_pb2_grpc = ( + MLBridgeTest_vec_double_pb2, + MLBridgeTest_vec_double_pb2_grpc, + ) + class DummyModel(nn.Module): def __init__(self, input_dim=10): @@ -96,34 +176,77 @@ def forward(self, input): returned_data = { 1: 12, - 2: ctypes.c_long(1234567891), + 2: 1234567891, 3: 4.14, - 4: ctypes.c_double(1.123456789123456789), + 4: 1.123456789123456789, 5: ord("b"), 6: False, 7: [12, 23, 34], - 8: [ctypes.c_long(123456780), ctypes.c_long(123456781), ctypes.c_long(123456782)], + 8: [123456780, 123456781, 123456782], 9: [1.11, 2.22, -3.33, 0], - 10: [ctypes.c_double(1.12345678912345670), ctypes.c_double(-1.12345678912345671)], + 10: [1.12345678912345670, -1.12345678912345671], } -# may not be configured for extended types -if args.data_format == "json": - returned_data[2] = ctypes.c_long(12345) - returned_data[8] = [ - ctypes.c_long(6780), - ctypes.c_long(6781), - ctypes.c_long(6782), - ] # [ctypes.c_long(6780),ctypes.c_long(6781),ctypes.c_long(6782)], +if args.use_pipe and args.data_format == "bytes": + returned_data.update( + { + 2: ctypes.c_long(1234567891), + 4: ctypes.c_double(1.123456789123456789), + 8: [ + ctypes.c_long(123456780), + ctypes.c_long(123456781), + ctypes.c_long(123456782), + ], + 10: [ + ctypes.c_double(1.12345678912345670), + ctypes.c_double(-1.12345678912345671), + ], + } + ) +if args.use_pipe and args.data_format == "json": + returned_data.update( + { + 2: ctypes.c_long(12345), + 4: ctypes.c_double(1.123456789123456789), + 8: [ctypes.c_long(6780), ctypes.c_long(6781), ctypes.c_long(6782)], + 10: [ + ctypes.c_double(1.12345678912345670), + ctypes.c_double(-1.12345678912345671), + ], + } + ) + +status = SUCCESS + +# test index vs received data +def checkData(index, data): + global status + if not args.silent: + print(" ", expected_type[index], "request:", data) + + if isinstance(expected_data[index], list): + for e, d in zip(expected_data[index], data): + if abs(e - d) > 10e-6: + print( + f"Error: Expected {expected_type[index]} request: {expected_data[index]}, Received: {data}" + ) + status = FAIL + # raise Exception(f"Mismatch in {expected_type[i]}") + + elif abs(data - expected_data[index]) > 10e-6: + print( + f"Error: Expected {expected_type[index]} request: {expected_data[index]}, Received: {data}" + ) + status = FAIL + # raise Exception(f"Mismatch in {expected_type[i]}") def run_pipe_communication(data_format, pipe_name): - compiler_interface = PipeCompilerInterface(data_format, "/tmp/" + pipe_name) + compiler_interface = PipeCompilerInterface(data_format, "./" + pipe_name) if not args.silent: print("PipeCompilerInterface init...") compiler_interface.reset_pipes() - status = SUCCESS i = 0 while True: i += 1 @@ -137,24 +260,7 @@ def run_pipe_communication(data_format, pipe_name): if len(data) == 1: data = data[0] - if not args.silent: - print(" ", expected_type[i], "request:", data) - - if isinstance(expected_data[i], list): - for e, d in zip(expected_data[i], data): - if abs(e - d) > 10e-6: - print( - f"Error: Expected {expected_type[i]} request: {expected_data[i]}, Received: {data}" - ) - status = FAIL - # raise Exception(f"Mismatch in {expected_type[i]}") - - elif abs(data - expected_data[i]) > 10e-6: - print( - f"Error: Expected {expected_type[i]} request: {expected_data[i]}, Received: {data}" - ) - status = FAIL - # raise Exception(f"Mismatch in {expected_type[i]}") + checkData(i, data) compiler_interface.populate_buffer(returned_data[i]) @@ -166,46 +272,56 @@ def run_pipe_communication(data_format, pipe_name): compiler_interface.reset_pipes() -class service_server(helloMLBridge_pb2_grpc.HelloMLBridgeService): - def __init__(self, data_format, pipe_name): - # self.serdes = SerDes.SerDes(data_format, pipe_name) - # self.serdes.init() +class service_server(MLBridgeTest_pb2_grpc.MLBridgeTestService): + def __init__(self): pass def getAdvice(self, request, context): try: - print(request) - print("Entered getAdvice") - print("Data: ", request.tensor) - reply = helloMLBridge_pb2.ActionRequest(action=1) + request_type = [var for var in dir(request) if "request" in var] + data = getattr(request, request_type[0]) + checkData(args.test_number, data) + if status == FAIL: + os.system("touch mlbridge-grpc-fail.txt") + reply = MLBridgeTest_pb2.Reply(action=returned_data[args.test_number]) return reply except: - reply = helloMLBridge_pb2.ActionRequest(action=-1) + reply = MLBridgeTest_pb2.Reply(action=-1) return reply -def test_func(): - data = 3.24 - import struct - - print(data, type(data)) - byte_data = struct.pack("f", data) - print(byte_data, len(byte_data)) - - print("decoding...") - decoded = float(byte_data) +def run_grpc_communication(): + # parent with test_number 0 spawns different servers + if args.test_number == 0: + process_list = [] + for i in range(1, len(expected_type) + 1): + p = subprocess.Popen( + f"python mlbridge-test.py --use_grpc --server_port={args.server_port} --silent={args.silent} --test_number={i}".split(), + ) + process_list.append(p) + + time.sleep(10) + global status + for p in process_list: + if os.path.isfile("mlbridge-grpc-fail.txt"): + status = FAIL + os.system("rm mlbridge-grpc-fail.txt") + p.terminate() + exit(status) - print(decoded, type(decoded)) + # servers serve different datatypes + else: + compiler_interface = GrpcCompilerInterface( + mode="server", + add_server_method=MLBridgeTest_pb2_grpc.add_MLBridgeTestServiceServicer_to_server, + grpc_service_obj=service_server(), + hostport=args.server_port + args.test_number, + ) + compiler_interface.start_server() if __name__ == "__main__": if args.use_pipe: run_pipe_communication(args.data_format, args.pipe_name) elif args.use_grpc: - compiler_interface = GrpcCompilerInterface( - mode="server", - add_server_method=helloMLBridge_pb2_grpc.add_HelloMLBridgeServiceServicer_to_server, - grpc_service_obj=service_server(), - hostport=args.server_port, - ) - compiler_interface.start_server() + run_grpc_communication() diff --git a/test/mlbridge-test.sh b/test/mlbridge-test.sh index 9ac1d49..a491f51 100644 --- a/test/mlbridge-test.sh +++ b/test/mlbridge-test.sh @@ -52,11 +52,11 @@ python $SERVER_FILE --use_pipe=True --data_format=json --pipe_name=mlbridgepipe2 SERVER_PID=$! run_test $BUILD_DIR/bin/MLCompilerBridgeTest --test-config=pipe-json --test-pipe-name=mlbridgepipe2 --silent -exit $STATUS - -# python $SERVER_FILE --use_grpc --server_port=50065 & -# echo "Test [grpc]:" -# run_test $BUILD_DIR/MLCompilerBridgeTest --test-config=grpc --test-server-address="0.0.0.0:50065" +echo -e "${BLUE}${BOLD}Testing MLBridge [grpc]${NC}" +python $SERVER_FILE --use_grpc --server_port=50155 --silent=True & +SERVER_PID=$! +run_test $BUILD_DIR/bin/MLCompilerBridgeTest --test-config=grpc --test-server-address="0.0.0.0:50155" --silent +exit $STATUS # echo "Test [onnx]:" # $BUILD_DIR/MLCompilerBridgeTest --test-config=onnx diff --git a/test/protos/MLBridgeTest_bool.proto b/test/protos/MLBridgeTest_bool.proto index 614294a..3cea3b7 100644 --- a/test/protos/MLBridgeTest_bool.proto +++ b/test/protos/MLBridgeTest_bool.proto @@ -6,5 +6,5 @@ service MLBridgeTestService { rpc getAdvice(Request) returns (Reply) {} } -message Request { bool data = 1; } +message Request { bool request_bool = 1; } message Reply { bool action = 1; } diff --git a/test/protos/MLBridgeTest_char.proto b/test/protos/MLBridgeTest_char.proto index d164cbd..4982760 100644 --- a/test/protos/MLBridgeTest_char.proto +++ b/test/protos/MLBridgeTest_char.proto @@ -6,5 +6,5 @@ service MLBridgeTestService { rpc getAdvice(Request) returns (Reply) {} } -message Request { string data = 1; } +message Request { string request_char = 1; } message Reply { string action = 1; } diff --git a/test/protos/MLBridgeTest_double.proto b/test/protos/MLBridgeTest_double.proto index d7528ee..8ad7c13 100644 --- a/test/protos/MLBridgeTest_double.proto +++ b/test/protos/MLBridgeTest_double.proto @@ -6,5 +6,5 @@ service MLBridgeTestService { rpc getAdvice(Request) returns (Reply) {} } -message Request { double data = 1; } +message Request { double request_double = 1; } message Reply { double action = 1; } diff --git a/test/protos/MLBridgeTest_float.proto b/test/protos/MLBridgeTest_float.proto index de88aed..d933fdf 100644 --- a/test/protos/MLBridgeTest_float.proto +++ b/test/protos/MLBridgeTest_float.proto @@ -6,5 +6,5 @@ service MLBridgeTestService { rpc getAdvice(Request) returns (Reply) {} } -message Request { float data = 1; } +message Request { float request_float = 1; } message Reply { float action = 1; } diff --git a/test/protos/MLBridgeTest_int.proto b/test/protos/MLBridgeTest_int.proto index 2314844..d21ff81 100644 --- a/test/protos/MLBridgeTest_int.proto +++ b/test/protos/MLBridgeTest_int.proto @@ -6,5 +6,5 @@ service MLBridgeTestService { rpc getAdvice(Request) returns (Reply) {} } -message Request { int32 data = 1; } +message Request { int32 request_int = 1; } message Reply { int32 action = 1; } diff --git a/test/protos/MLBridgeTest_long.proto b/test/protos/MLBridgeTest_long.proto index 6bbfce8..41a1419 100644 --- a/test/protos/MLBridgeTest_long.proto +++ b/test/protos/MLBridgeTest_long.proto @@ -6,5 +6,5 @@ service MLBridgeTestService { rpc getAdvice(Request) returns (Reply) {} } -message Request { int64 data = 1; } +message Request { int64 request_long = 1; } message Reply { int64 action = 1; } diff --git a/test/protos/MLBridgeTest_vec_double.proto b/test/protos/MLBridgeTest_vec_double.proto index 669512d..7911efd 100644 --- a/test/protos/MLBridgeTest_vec_double.proto +++ b/test/protos/MLBridgeTest_vec_double.proto @@ -6,5 +6,5 @@ service MLBridgeTestService { rpc getAdvice(Request) returns (Reply) {} } -message Request { repeated double data = 1; } +message Request { repeated double request_vec_double = 1; } message Reply { repeated double action = 1; } diff --git a/test/protos/MLBridgeTest_vec_float.proto b/test/protos/MLBridgeTest_vec_float.proto index c6da760..9bcc675 100644 --- a/test/protos/MLBridgeTest_vec_float.proto +++ b/test/protos/MLBridgeTest_vec_float.proto @@ -6,5 +6,5 @@ service MLBridgeTestService { rpc getAdvice(Request) returns (Reply) {} } -message Request { repeated float data = 1; } +message Request { repeated float request_vec_float = 1; } message Reply { repeated float action = 1; } diff --git a/test/protos/MLBridgeTest_vec_int.proto b/test/protos/MLBridgeTest_vec_int.proto index 9bb741b..2fb2298 100644 --- a/test/protos/MLBridgeTest_vec_int.proto +++ b/test/protos/MLBridgeTest_vec_int.proto @@ -6,5 +6,5 @@ service MLBridgeTestService { rpc getAdvice(Request) returns (Reply) {} } -message Request { repeated int32 data = 1; } +message Request { repeated int32 request_vec_int = 1; } message Reply { repeated int32 action = 1; } diff --git a/test/protos/MLBridgeTest_vec_long.proto b/test/protos/MLBridgeTest_vec_long.proto index d82bf92..9757d1d 100644 --- a/test/protos/MLBridgeTest_vec_long.proto +++ b/test/protos/MLBridgeTest_vec_long.proto @@ -6,5 +6,5 @@ service MLBridgeTestService { rpc getAdvice(Request) returns (Reply) {} } -message Request { repeated int64 data = 1; } +message Request { repeated int64 request_vec_long = 1; } message Reply { repeated int64 action = 1; }