Skip to content

Commit

Permalink
Added ONNX test
Browse files Browse the repository at this point in the history
  • Loading branch information
RajivChitale committed Feb 18, 2024
1 parent 5417576 commit f75b196
Show file tree
Hide file tree
Showing 6 changed files with 59 additions and 54 deletions.
32 changes: 28 additions & 4 deletions test/MLBridgeTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,30 @@ int testGRPC() {
return 0;
}

int testONNX() { return 0; }
class ONNXTest : public MLBridgeTestEnv {
public:
int run(int expectedAction) {
Agent *agent = new Agent("/home/cs21btech11051/ml-llvm-project/"
"MLCompilerBridge/test/onnx/dummy_model.onnx");
FeatureVector.clear();
int n = 100;
for (int i = 0; i < n; i++) {
float delta = (float)(i - expectedAction) / n;
FeatureVector.push_back(delta * delta);
}

std::map<std::string, Agent *> agents;
agents["agent"] = agent;
MLRunner = std::make_unique<ONNXModelRunner>(this, agents, nullptr);
MLRunner->evaluate<int>();
if (lastAction != expectedAction) {
std::cerr << "Error: Expected action: " << expectedAction
<< ", Computed action: " << lastAction << "\n";
exit(1);
}
return 0;
}
};

} // namespace

Expand All @@ -236,9 +259,10 @@ int main(int argc, char **argv) {
} else if (test_config == "grpc") {
server_address = cl_server_address.getValue();
testGRPC();
} else if (test_config == "onnx")
testONNX();
else
} else if (test_config == "onnx") {
ONNXTest t;
t.run(20);
} else
std::cerr << "--test-config must be provided from [pipe-bytes, pipe-json, "
"grpc, onnx]\n";
return 0;
Expand Down
3 changes: 2 additions & 1 deletion test/include/HelloMLBridge_Env.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ class MLBridgeTestEnv : public Environment {
MLBridgeTestEnv() { setNextAgent("agent"); };
Observation &reset() override;
Observation &step(Action) override;
Action lastAction;

protected:
std::vector<float> FeatureVector;
Expand All @@ -28,7 +29,7 @@ Observation &MLBridgeTestEnv::step(Action Action) {
CurrObs.clear();
std::copy(FeatureVector.begin(), FeatureVector.end(),
std::back_inserter(CurrObs));
llvm::outs() << "Action: " << Action << "\n";
lastAction = Action;
setDone();
return CurrObs;
}
Expand Down
40 changes: 0 additions & 40 deletions test/inference/HelloMLBridge_Env.h

This file was deleted.

33 changes: 26 additions & 7 deletions test/mlbridge-test.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import sys
import os
import torch, torch.nn as nn
import torch.onnx
import subprocess
import time

Expand All @@ -29,9 +30,7 @@
SUCCESS = 0

parser = argparse.ArgumentParser()
parser.add_argument(
"--use_pipe", type=bool, default=False, help="Use pipe or not", required=False
)
parser.add_argument("--use_pipe", default=False, help="Use pipe or not", required=False)
parser.add_argument(
"--data_format",
type=str,
Expand Down Expand Up @@ -64,6 +63,12 @@
help="Datatype number for test",
default=0,
)
parser.add_argument(
"--export_onnx",
help="Export onnx test model",
required=False,
default=False,
)
args = parser.parse_args()

if args.test_number <= 1:
Expand Down Expand Up @@ -139,13 +144,25 @@


class DummyModel(nn.Module):
def __init__(self, input_dim=10):
def __init__(self):
nn.Module.__init__(self)
self.fc1 = nn.Linear(input_dim, 1)

def forward(self, input):
x = self.fc1(input)
return x
return 2 - input


def export_onnx_model(input_dim=100):
onnx_filename = "./onnx/dummy_model.onnx"
dummy_value = torch.randn(1, input_dim)
torch.onnx.export(
DummyModel(),
dummy_value,
onnx_filename,
input_names=["obs"],
verbose=True,
export_params=True,
)
print(f"Model exported to {onnx_filename}")


expected_type = {
Expand Down Expand Up @@ -325,3 +342,5 @@ def run_grpc_communication():
run_pipe_communication(args.data_format, args.pipe_name)
elif args.use_grpc:
run_grpc_communication()
elif args.export_onnx:
export_onnx_model()
5 changes: 3 additions & 2 deletions test/mlbridge-test.sh
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ python $SERVER_FILE --use_grpc --server_port=50155 --silent=True &
SERVER_PID=$!
run_test $BUILD_DIR/bin/MLCompilerBridgeTest --test-config=grpc --test-server-address="0.0.0.0:50155" --silent

echo -e "${BLUE}${BOLD}Testing MLBridge [onnx]${NC}"
run_test $BUILD_DIR/bin/MLCompilerBridgeTest --test-config=onnx

exit $STATUS
# echo "Test [onnx]:"
# $BUILD_DIR/MLCompilerBridgeTest --test-config=onnx
Binary file added test/onnx/dummy_model.onnx
Binary file not shown.

0 comments on commit f75b196

Please sign in to comment.