diff --git a/CMakeLists.txt b/CMakeLists.txt
index f17048d169..32e39a0e1a 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -116,6 +116,12 @@ elseif ((ONNX_USE_PROTOBUF_SHARED_LIBS AND Protobuf_USE_STATIC_LIBS)
     "ONNX_USE_PROTOBUF_SHARED_LIBS and Protobuf_USE_STATIC_LIBS must be opposites of each other.")
+# Use the new MSVC preprocessor to improve standard conformance.
+  set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} /Zc:preprocessor")
+  set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /Zc:preprocessor")
 # Suppress warnings in third party code.
diff --git a/src/Conversion/ONNXToKrnl/CMakeLists.txt b/src/Conversion/ONNXToKrnl/CMakeLists.txt
index a852bf30ae..1591c6497d 100644
--- a/src/Conversion/ONNXToKrnl/CMakeLists.txt
+++ b/src/Conversion/ONNXToKrnl/CMakeLists.txt
@@ -42,8 +42,9 @@ add_onnx_mlir_library(OMONNXToKrnl
-  Tensor/Gather.cpp 
-  Tensor/GatherElements.cpp  
+  Tensor/Gather.cpp
+  Tensor/GatherElements.cpp
+  Tensor/GatherND.cpp
diff --git a/src/Conversion/ONNXToKrnl/ConvertONNXToKrnl.cpp b/src/Conversion/ONNXToKrnl/ConvertONNXToKrnl.cpp
index 84fce39f9f..3962a6282f 100644
--- a/src/Conversion/ONNXToKrnl/ConvertONNXToKrnl.cpp
+++ b/src/Conversion/ONNXToKrnl/ConvertONNXToKrnl.cpp
@@ -87,6 +87,7 @@ void populateONNXToKrnlConversionPattern(RewritePatternSet &patterns,
   populateLoweringONNXTransposeOpPattern(patterns, typeConverter, ctx);
   populateLoweringONNXGatherOpPattern(patterns, typeConverter, ctx);
   populateLoweringONNXGatherElementsOpPattern(patterns, typeConverter, ctx);
+  populateLoweringONNXGatherNDOpPattern(patterns, typeConverter, ctx);
   populateLoweringONNXIdentityOpPattern(patterns, typeConverter, ctx);
   populateLoweringONNXConstantOfShapeOpPattern(patterns, typeConverter, ctx);
   populateLoweringONNXConstantOpPattern(patterns, typeConverter, ctx);
diff --git a/src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.cpp b/src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.cpp
index 2ec7832fd4..c83c9b8756 100644
--- a/src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.cpp
+++ b/src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.cpp
@@ -453,10 +453,9 @@ Value foldOrEmitONNXTransposeOp(ConversionPatternRewriter &rewriter,
 /// Emit MemRef ReinterpretCastOp to create a new view for 'data'.
-/// The new view is created using the given 'memRefType' and 'outputDims'.
+/// The new view is created using the given 'outputDims'.
 Value emitMemRefReinterpretCastOp(ConversionPatternRewriter &rewriter,
-    Location loc, Value data, const MemRefType &memRefType,
-    SmallVectorImpl<IndexExpr> &outputDims) {
+    Location loc, Value data, SmallVectorImpl<IndexExpr> &outputDims) {
   MemRefBuilder createMemRef(rewriter, loc);
   return createMemRef.reinterpretCast(data, outputDims);
diff --git a/src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.hpp b/src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.hpp
index 249fce812d..92819a9620 100644
--- a/src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.hpp
+++ b/src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.hpp
@@ -145,10 +145,9 @@ Value foldOrEmitONNXTransposeOp(ConversionPatternRewriter &rewriter,
     Location loc, Type resultType, Value input, ArrayAttr permAttr);
 /// Emit MemRef ReinterpretCastOp to create a new view for 'data'.
-/// The new view is created using the given 'memRefType' and 'outputDims'.
+/// The new view is created using the given 'outputDims'.
 Value emitMemRefReinterpretCastOp(ConversionPatternRewriter &rewriter,
-    Location loc, Value data, const MemRefType &memRefType,
-    SmallVectorImpl<IndexExpr> &outputDims);
+    Location loc, Value data, SmallVectorImpl<IndexExpr> &outputDims);
 /// Emit krnl iterate to compute argsort of a given MemRef along a given axis.
 /// Output MemRef has the same shape as the input MemRef but is of IndexType.
@@ -321,6 +320,8 @@ void populateLoweringONNXGatherOpPattern(
     RewritePatternSet &, TypeConverter &, MLIRContext *);
 void populateLoweringONNXGatherElementsOpPattern(
     RewritePatternSet &, TypeConverter &, MLIRContext *);
+void populateLoweringONNXGatherNDOpPattern(
+    RewritePatternSet &, TypeConverter &, MLIRContext *);
 void populateLoweringONNXPadConstantValuePadOpPattern(
     RewritePatternSet &, TypeConverter &, MLIRContext *);
 void populateLoweringONNXPadOpPattern(
diff --git a/src/Conversion/ONNXToKrnl/Tensor/GatherND.cpp b/src/Conversion/ONNXToKrnl/Tensor/GatherND.cpp
new file mode 100644
index 0000000000..7c524bfdd7
--- /dev/null
+++ b/src/Conversion/ONNXToKrnl/Tensor/GatherND.cpp
@@ -0,0 +1,269 @@
+ * SPDX-License-Identifier: Apache-2.0
+ */
+//===---------------- GatherND.cpp - Lowering GatherND Op -----------------===//
+// Copyright 2022 The IBM Research Authors.
+// =============================================================================
+// This file lowers the ONNX GatherND Operator to Krnl dialect.
+#include "src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.hpp"
+#include "src/Dialect/ONNX/ShapeInference/ONNXShapeHelper.hpp"
+#include "llvm/Support/Debug.h"
+#include <numeric>
+#define DEBUG_TYPE "gather_nd_onnx_to_krnl"
+using namespace mlir;
+namespace onnx_mlir {
+struct ONNXGatherNDOpLowering : public ConversionPattern {
+  ONNXGatherNDOpLowering(TypeConverter &typeConverter, MLIRContext *ctx)
+      : ConversionPattern(
+            typeConverter, ONNXGatherNDOp::getOperationName(), 1, ctx) {}
+  // When true causes injection of print stmts in the generated code.
+  static constexpr bool emitPrintStmts = false;
+  // Debug helper: emits krnl print operations that display 'title' followed
+  // by the value of every index expression in 'indices', wrapped in
+  // parentheses and terminated by a newline.
+  static void printIndices(
+      StringRef title, const DimsExpr &indices, KrnlBuilder &createKrnl) {
+    createKrnl.printf((title + ": (").str());
+    for (const IndexExpr &indexExpr : indices) {
+      Value indexVal = indexExpr.getValue();
+      createKrnl.printf(indexVal, indexVal.getType());
+    }
+    createKrnl.printf(")\n");
+  }
+  /// Lowers an ONNXGatherNDOp to the Krnl dialect.
+  ///
+  /// 'indices' is viewed as [batchDimsSize, indicesDimsSize, indices.shape[-1]]
+  /// and 'data' as [batchDimsSize, data.shape[b:]]. A 2-level krnl loop over
+  /// the first two dimensions of the reshaped 'indices' appends every gathered
+  /// value to a 1D buffer, which is finally reinterpreted with the output
+  /// shape and used to replace 'op'.
+  ///
+  /// Returns success(); violated operator constraints are caught by
+  /// assertions only.
+  LogicalResult matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+      ConversionPatternRewriter &rewriter) const final {
+    ONNXGatherNDOpAdaptor operandAdaptor(operands);
+    ONNXGatherNDOp gatherNDOp = cast<ONNXGatherNDOp>(op);
+    Location loc = op->getLoc();
+    MultiDialectBuilder<KrnlBuilder, MathBuilder, MemRefBuilder> create(
+        rewriter, loc);
+    IndexExprScope outerScope(&rewriter, loc);
+    ONNXGatherNDOpShapeHelper shapeHelper(&gatherNDOp, &rewriter,
+        krnl::getDenseElementAttributeFromKrnlValue,
+        krnl::loadDenseElementArrayValueAtIndex);
+    auto shapecomputed = shapeHelper.computeShape(operandAdaptor);
+    assert(succeeded(shapecomputed) && "Could not compute output shape");
+    (void)shapecomputed; // Only read by the assert above.
+    // Operands and attributes.
+    Value data = operandAdaptor.data();
+    Value indices = operandAdaptor.indices();
+    int64_t b = gatherNDOp.batch_dims();
+    auto indicesType = indices.getType().cast<ShapedType>();
+    auto dataType = data.getType().cast<ShapedType>();
+    ArrayRef<int64_t> indicesShape = indicesType.getShape();
+    ArrayRef<int64_t> dataShape = dataType.getShape();
+    int64_t dataRank = dataShape.size();
+    int64_t indicesRank = indicesShape.size();
+    int64_t indicesLastDim = indicesShape[indicesRank - 1];
+    // Convert the output type to MemRefType.
+    Type convertedType = typeConverter->convertType(*op->result_type_begin());
+    assert(convertedType && convertedType.isa<MemRefType>() &&
+           "Failed to convert type to MemRefType");
+    MemRefType outputMemRefType = convertedType.cast<MemRefType>();
+    ArrayRef<int64_t> outputShape = outputMemRefType.getShape();
+    int64_t outputRank = outputShape.size();
+    (void)outputRank; // Only read by the assert below.
+    // Ensure the operation constraints are satisfied.
+    assert(dataRank >= 1 && "The rank of 'data' must be >= 1");
+    assert(indicesRank >= 1 && "The rank of 'indices' must be >= 1");
+    assert((outputRank == dataRank + indicesRank - indicesLastDim - 1 - b) &&
+           "Incorrect outut rank");
+    assert(b >= 0 && "batch_dim should not be negative");
+    assert(b < std::min(dataRank, indicesRank) &&
+           "batch_dims must be smaller than the min(dataRank, indicesRank)");
+    assert((indicesLastDim >= 1 && indicesLastDim <= dataRank - b) &&
+           "indices.shape[-1] must be in the range [1, dataRank - b]");
+    // Reshape 'indices' to the 3D shape:
+    //   [batchDimSize, indicesDimsSize, indices.shape[-1]].
+    // The initial value passed to std::accumulate is an int64_t so that the
+    // products are accumulated in 64 bits; an 'int' initial value would make
+    // the accumulator an 'int' and truncate large products.
+    const int64_t batchDimsSize = std::accumulate(indicesShape.begin(),
+        indicesShape.begin() + b, int64_t{1}, std::multiplies<int64_t>());
+    const int64_t indicesDimsSize = std::accumulate(indicesShape.begin(),
+        indicesShape.end(), int64_t{1}, std::multiplies<int64_t>());
+    assert(batchDimsSize >= 0 && "batchDimsSize must be non-negative");
+    assert(indicesDimsSize >= 0 && "indicesDimsSize must be non-negative");
+    LiteralIndexExpr BDS(batchDimsSize),
+        IDS(indicesDimsSize / (batchDimsSize * indicesLastDim)),
+        ILD(indicesLastDim);
+    DimsExpr newIndicesShape = {BDS, IDS, ILD};
+    Value reshapedIndices =
+        create.mem.reinterpretCast(indices, newIndicesShape);
+    LLVM_DEBUG(llvm::dbgs() << "reshapedIndices: " << reshapedIndices << "\n");
+    // Reshape 'data' to shape [batchDimSize, data.shape[b:]]
+    DimsExpr newDataShape = {BDS};
+    for (int64_t i = b; i < dataRank; ++i) {
+      assert(dataShape[i] > 0 && "Cannot support data with dynamic dimensions");
+      LiteralIndexExpr dataDim(dataShape[i]);
+      newDataShape.emplace_back(dataDim);
+    }
+    int64_t reshapedDataRank = newDataShape.size();
+    (void)reshapedDataRank; // Only read by asserts inside the loop body.
+    Value reshapedData = create.mem.reinterpretCast(data, newDataShape);
+    LLVM_DEBUG(llvm::dbgs() << "reshapedData: " << reshapedData << "\n");
+    // Allocate a 1D output buffer.
+    const int64_t outputDimsSize = std::accumulate(outputShape.begin(),
+        outputShape.end(), int64_t{1}, std::multiplies<int64_t>());
+    Value outputDataBuffer = create.mem.alloc(
+        MemRefType::get({outputDimsSize}, outputMemRefType.getElementType()));
+    // Initialize the index used to store the result values.
+    Value iZero = create.math.constantIndex(0);
+    Value iOne = create.math.constantIndex(1);
+    Value storeIndex =
+        create.mem.alloca(MemRefType::get({}, rewriter.getIndexType()));
+    create.krnl.store(iZero, storeIndex);
+    // for (i,j) in (0..reshapedIndices.shape[0], 0..reshapedIndices.shape[1])
+    // {
+    //   idx = tuple(reshapedIndices[i][j])
+    //   output.append(reshapedData[(i,) + idx])
+    // }
+    // output.reshape(outputShape)
+    ValueRange loopDef = create.krnl.defineLoops(2);
+    DimsExpr lbs(2, LiteralIndexExpr(0)),
+        ubs = {newIndicesShape[0], newIndicesShape[1]};
+    if (emitPrintStmts) {
+      create.krnl.printTensor("reshapedIndices: ", reshapedIndices);
+      create.krnl.printTensor("reshapedData: ", reshapedData);
+    }
+    create.krnl.iterateIE(loopDef, loopDef, lbs, ubs,
+        [&](KrnlBuilder &createKrnl, ValueRange loopInd) {
+          // Insert code inside the loop.
+          IndexExprScope innerLoopScope(createKrnl);
+          // Access function for 'reshapedIndices'. The first 2 indices are
+          // equal to the loop indexes.
+          DimsExpr reshapedIndicesAccessFct;
+          getIndexExprList<DimIndexExpr>(loopInd, reshapedIndicesAccessFct);
+          // Access function for 'reshapedData'. The first index is equal to the
+          // first loop index.
+          DimsExpr reshapedDataAccessFct;
+          IndexExpr ind = SymbolIndexExpr(loopInd[0]);
+          reshapedDataAccessFct.emplace_back(ind);
+          // The last index of the access function for 'reshapedIndices' is
+          // given by the values of indices.shape[-1].
+          // The loaded values from 'reshapedIndices' are the next set of
+          // indices to push to the `reshapedDataAccessFct`.
+          for (unsigned i = 0; i < indicesLastDim; ++i) {
+            IndexExpr ind = LiteralIndexExpr(i);
+            reshapedIndicesAccessFct.emplace_back(ind);
+            if (emitPrintStmts)
+              printIndices("indices", reshapedIndicesAccessFct, createKrnl);
+            Value indexVal =
+                createKrnl.loadIE(reshapedIndices, reshapedIndicesAccessFct);
+            reshapedIndicesAccessFct.pop_back();
+            if (emitPrintStmts) {
+              createKrnl.printf("index = ", indexVal, indexVal.getType());
+              createKrnl.printf("\n");
+            }
+            IndexExpr index = NonAffineIndexExpr(indexVal);
+            reshapedDataAccessFct.emplace_back(index);
+          }
+          if (indicesLastDim == dataRank - b) {
+            // When indices.shape[-1] is equal to (rank(data) - b) the
+            // `reshapedDataAccessFct` computed so far has the same number of
+            // indices as the rank of 'reshapedData'.
+            assert((int64_t)reshapedDataAccessFct.size() == reshapedDataRank &&
+                   "Access function should have the same rank as reshapedData");
+            if (emitPrintStmts)
+              printIndices("data indices", reshapedDataAccessFct, createKrnl);
+            // Gather value from the 'data' tensor and store it into
+            // 'outputDataBuffer'.
+            Value val = createKrnl.loadIE(reshapedData, reshapedDataAccessFct);
+            Value storeIndexVal = createKrnl.load(storeIndex);
+            createKrnl.store(val, outputDataBuffer, storeIndexVal);
+            // Bump up the storeIndex.
+            createKrnl.store(create.math.add(storeIndexVal, iOne), storeIndex);
+          } else {
+            assert((indicesLastDim < dataRank - b) &&
+                   "Expecting indices.shape[-1] to be smaller than "
+                   "rank(data) - b");
+            // When indices.shape[-1] is less than (rank(data) - b) the
+            // `reshapedDataAccessFct` computed so far yields a slice which
+            // needs to be inserted into the output buffer.
+            // NOTE(review): only the last dimension of 'data' is iterated
+            // here, so the rank assertion below fires whenever more than one
+            // trailing dimension remains, i.e. when
+            // indices.shape[-1] < rank(data) - b - 1. Confirm whether
+            // multi-dimensional slices are meant to be supported.
+            int64_t reshapedDataLastDim = dataShape[dataRank - 1];
+            for (int64_t i = 0; i < reshapedDataLastDim; ++i) {
+              IndexExpr ind = LiteralIndexExpr(i);
+              reshapedDataAccessFct.emplace_back(ind);
+              assert(
+                  (int64_t)reshapedDataAccessFct.size() == reshapedDataRank &&
+                  "Access function should have the same rank as reshapedData");
+              if (emitPrintStmts)
+                printIndices("data indices", reshapedDataAccessFct, createKrnl);
+              // Gather value from the 'data' tensor and store it into
+              // 'outputDataBuffer'.
+              Value val =
+                  createKrnl.loadIE(reshapedData, reshapedDataAccessFct);
+              reshapedDataAccessFct.pop_back();
+              if (emitPrintStmts) {
+                createKrnl.printf("val = ", val, val.getType());
+                createKrnl.printf("\n");
+              }
+              Value storeIndexVal = createKrnl.load(storeIndex);
+              createKrnl.store(val, outputDataBuffer, storeIndexVal);
+              // Bump up the storeIndex.
+              createKrnl.store(
+                  create.math.add(storeIndexVal, iOne), storeIndex);
+            }
+          }
+        });
+    // Finally reshape 'outputDataBuffer' to the shape of the output.
+    DimsExpr newOutputShape;
+    for (int64_t dim : outputShape) {
+      LiteralIndexExpr outputDim(dim);
+      newOutputShape.emplace_back(outputDim);
+    }
+    Value reshapedOutput =
+        create.mem.reinterpretCast(outputDataBuffer, newOutputShape);
+    LLVM_DEBUG(llvm::dbgs() << "reshapedOutput: " << reshapedOutput << "\n");
+    rewriter.replaceOp(op, reshapedOutput);
+    return success();
+  }
+void populateLoweringONNXGatherNDOpPattern(RewritePatternSet &patterns,
+    TypeConverter &typeConverter, MLIRContext *ctx) {
+  patterns.insert<ONNXGatherNDOpLowering>(typeConverter, ctx);
+} // namespace onnx_mlir
diff --git a/src/Conversion/ONNXToKrnl/Tensor/Reshape.cpp b/src/Conversion/ONNXToKrnl/Tensor/Reshape.cpp
index 3ecaa75a1f..583b6beffe 100644
--- a/src/Conversion/ONNXToKrnl/Tensor/Reshape.cpp
+++ b/src/Conversion/ONNXToKrnl/Tensor/Reshape.cpp
@@ -51,7 +51,7 @@ struct ONNXReshapeOpLowering : public ConversionPattern {
     // Lower to ReinterpretCastOp so that the data is never copied or modified.
     Value newView = emitMemRefReinterpretCastOp(
-        rewriter, loc, data, memRefType, shapeHelper.dimsForOutput());
+        rewriter, loc, data, shapeHelper.dimsForOutput());
     LLVM_DEBUG(llvm::dbgs() << "newView: " << newView << "\n");
     rewriter.replaceOp(op, newView);
diff --git a/src/Conversion/ONNXToKrnl/Tensor/Squeeze.cpp b/src/Conversion/ONNXToKrnl/Tensor/Squeeze.cpp
index 2d0f99d523..d674d2b12a 100644
--- a/src/Conversion/ONNXToKrnl/Tensor/Squeeze.cpp
+++ b/src/Conversion/ONNXToKrnl/Tensor/Squeeze.cpp
@@ -26,15 +26,9 @@ LogicalResult ONNXSqueezeOpLoweringCommon(Operation *op,
   Adaptor operandAdaptor(operands);
   Op squeezeOp = dyn_cast_or_null<Op>(op);
-  auto loc = op->getLoc();
+  Location loc = op->getLoc();
   Value data = operandAdaptor.data();
-  // Convert the output type to MemRefType.
-  Type convertedType = typeConverter->convertType(*op->result_type_begin());
-  assert(convertedType && convertedType.isa<MemRefType>() &&
-         "Failed to convert type to MemRefType");
-  MemRefType memRefType = convertedType.cast<MemRefType>();
   ShapeHelper shapeHelper(&squeezeOp, &rewriter,
@@ -43,7 +37,7 @@ LogicalResult ONNXSqueezeOpLoweringCommon(Operation *op,
   // Lower to ReinterpretCastOp so that the data is never copied or modified.
   Value newView = emitMemRefReinterpretCastOp(
-      rewriter, loc, data, memRefType, shapeHelper.dimsForOutput());
+      rewriter, loc, data, shapeHelper.dimsForOutput());
   rewriter.replaceOp(op, newView);
   return success();
diff --git a/src/Conversion/ONNXToKrnl/Tensor/Unsqueeze.cpp b/src/Conversion/ONNXToKrnl/Tensor/Unsqueeze.cpp
index 7ce98f3e4f..76ad2e80b1 100644
--- a/src/Conversion/ONNXToKrnl/Tensor/Unsqueeze.cpp
+++ b/src/Conversion/ONNXToKrnl/Tensor/Unsqueeze.cpp
@@ -26,15 +26,9 @@ LogicalResult ONNXUnsqueezeOpLoweringCommon(Operation *op,
   Adaptor operandAdaptor(operands);
   Op unsqueezeOp = dyn_cast_or_null<Op>(op);
-  auto loc = op->getLoc();
+  Location loc = op->getLoc();
   Value data = operandAdaptor.data();
-  // Convert the output type to MemRefType.
-  Type convertedType = typeConverter->convertType(*op->result_type_begin());
-  assert(convertedType && convertedType.isa<MemRefType>() &&
-         "Failed to convert type to MemRefType");
-  MemRefType memRefType = convertedType.cast<MemRefType>();
   ShapeHelper shapeHelper(&unsqueezeOp, &rewriter,
@@ -43,7 +37,7 @@ LogicalResult ONNXUnsqueezeOpLoweringCommon(Operation *op,
   // Lower to ReinterpretCastOp so that the data is never copied or modified.
   Value newView = emitMemRefReinterpretCastOp(
-      rewriter, loc, data, memRefType, shapeHelper.dimsForOutput());
+      rewriter, loc, data, shapeHelper.dimsForOutput());
   rewriter.replaceOp(op, newView);
   return success();
diff --git a/src/Dialect/Krnl/DialectBuilder.cpp b/src/Dialect/Krnl/DialectBuilder.cpp
index 5c7af54d06..cb743b2326 100644
--- a/src/Dialect/Krnl/DialectBuilder.cpp
+++ b/src/Dialect/Krnl/DialectBuilder.cpp
@@ -25,6 +25,38 @@ using namespace mlir;
 namespace onnx_mlir {
+// Returns the printf-style conversion specifier used to print a value of
+// 'inputType'. Integer widths other than 1, 8, 16, 32 and 64 yield an empty
+// string; a type without a matching case is reported on llvm::errs() and
+// treated as unreachable.
+static StringRef getFormat(const Type &inputType) {
+  return TypeSwitch<Type, StringRef>(inputType)
+      .Case<Float16Type>([](Float16Type) -> StringRef { return "%g"; })
+      .Case<Float32Type>([](Float32Type) -> StringRef { return "%f"; })
+      .Case<Float64Type>([](Float64Type) -> StringRef { return "%f"; })
+      .Case<IntegerType>([](IntegerType intType) -> StringRef {
+        const bool isUnsigned = intType.isUnsigned();
+        switch (intType.getWidth()) {
+        case 1:
+        case 8:
+        case 16:
+        case 32:
+          return isUnsigned ? "%u" : "%d";
+        case 64:
+          return isUnsigned ? "%llu" : "%lld";
+        }
+        // Any other width: no specifier.
+        return StringRef();
+      })
+      .Case<IndexType>([](IndexType) -> StringRef { return "%lld"; })
+      .Case<onnx_mlir::krnl::StringType>(
+          [](onnx_mlir::krnl::StringType) -> StringRef { return "%s"; })
+      .Case<LLVM::LLVMPointerType>(
+          [](LLVM::LLVMPointerType) -> StringRef { return "%s"; })
+      .Default([](Type type) -> StringRef {
+        llvm::errs() << "type: " << type << "\n";
+        llvm_unreachable("Unhandled type");
+      });
 //====---------------- Support for Krnl Builder ----------------------===//
 Value KrnlBuilder::load(Value memref, ValueRange indices) const {
@@ -208,37 +240,14 @@ void KrnlBuilder::printf(StringRef msg) const {
 void KrnlBuilder::printf(StringRef msg, Value input, Type inputType) const {
-  StringRef format;
-  TypeSwitch<Type>(inputType)
-      .Case<mlir::Float16Type>([&](mlir::Float16Type) { format = "%g\n"; })
-      .Case<mlir::Float32Type>([&](mlir::Float32Type) { format = "%g\n"; })
-      .Case<mlir::Float64Type>([&](mlir::Float64Type) { format = "%g\n"; })
-      .Case<IntegerType>([&](IntegerType type) {
-        switch (type.getWidth()) {
-        case 1:
-        case 8:
-        case 16:
-        case 32:
-          format = type.isUnsigned() ? "%u\n" : "%d\n";
-          break;
-        case 64:
-          format = type.isUnsigned() ? "%llu\n" : "%lld\n";
-          break;
-        }
-      })
-      .Case<IndexType>([&](IndexType) { format = "%lld\n"; })
-      .Case<onnx_mlir::krnl::StringType>(
-          [&](onnx_mlir::krnl::StringType) { format = "%s\n"; })
-      .Case<LLVM::LLVMPointerType>(
-          [&](LLVM::LLVMPointerType) { format = "%s\n"; })
-      .Default([&](Type type) {
-        llvm::errs() << "type: " << type << "\n";
-        llvm_unreachable("Unhandled type");
-      });
+  StringRef format = getFormat(inputType);
   std::string concat(msg.str() + format.str());
   StringRef newFormat(concat);
   b.create<KrnlPrintOp>(loc, newFormat, input);
+void KrnlBuilder::printf(Value input, Type inputType) const {
+  // Print 'input' alone, using the format derived from 'inputType'.
+  b.create<KrnlPrintOp>(loc, getFormat(inputType), input);
 } // namespace onnx_mlir
diff --git a/src/Dialect/Krnl/DialectBuilder.hpp b/src/Dialect/Krnl/DialectBuilder.hpp
index c2c2317bda..f253979996 100644
--- a/src/Dialect/Krnl/DialectBuilder.hpp
+++ b/src/Dialect/Krnl/DialectBuilder.hpp
@@ -136,6 +136,7 @@ struct KrnlBuilder : public DialectBuilder {
   void printf(mlir::StringRef msg) const;
   void printf(
       mlir::StringRef msg, mlir::Value input, mlir::Type inputType) const;
+  void printf(mlir::Value input, mlir::Type inputType) const;
   // Onnx-mlir runtime functions.
   void randomNormal(mlir::Value alloc, mlir::Value numberOfRandomValues,
diff --git a/src/Dialect/ONNX/CMakeLists.txt b/src/Dialect/ONNX/CMakeLists.txt
index 49d5634a15..a0779def68 100644
--- a/src/Dialect/ONNX/CMakeLists.txt
+++ b/src/Dialect/ONNX/CMakeLists.txt
@@ -31,6 +31,7 @@ add_onnx_mlir_library(OMONNXOps
+  ShapeInference/GatherND.cpp  
diff --git a/src/Dialect/ONNX/ONNXOps.cpp b/src/Dialect/ONNX/ONNXOps.cpp
index 3d35575c98..c10260f919 100644
--- a/src/Dialect/ONNX/ONNXOps.cpp
+++ b/src/Dialect/ONNX/ONNXOps.cpp
@@ -31,6 +31,7 @@
 #include "src/Dialect/ONNX/ShapeInference/ONNXShapeHelper.hpp"
 #include "src/Support/Diagnostic.hpp"
+#include <algorithm>
 #include <string>
 using namespace mlir;
@@ -3374,6 +3375,114 @@ LogicalResult ONNXGatherElementsOp::inferShapes(
       ONNXGatherElementsOp, ONNXGatherElementsOpAdaptor>(*this, elementType);
+// GatherND
+/// Verifies the GatherND operator constraints that can be checked statically:
+///   - 'data' and 'indices' have rank >= 1,
+///   - batch_dims (b) lies in [0, min(rank(data), rank(indices)) - 1],
+///   - the first b dimensions of 'data' and 'indices' agree when both are
+///     static,
+///   - indices.shape[-1] lies in [1, rank(data) - b] when it is static,
+///   - constant index values are within the bounds of the 'data' axis they
+///     address.
+/// Returns success() without checking anything while an operand has no
+/// shape/rank yet.
+LogicalResult ONNXGatherNDOp::verify() {
+  ONNXGatherNDOpAdaptor operandAdaptor(*this);
+  if (llvm::any_of(operandAdaptor.getOperands(),
+          [](const Value &op) { return !hasShapeAndRank(op); }))
+    return success(); // Won't be able to do any checking at this stage.
+  // Get operands and attributes.
+  Value data = operandAdaptor.data();
+  Value indices = operandAdaptor.indices();
+  auto dataType = data.getType().cast<ShapedType>();
+  auto indicesType = indices.getType().cast<ShapedType>();
+  int64_t dataRank = dataType.getRank();
+  int64_t indicesRank = indicesType.getRank();
+  int64_t b = batch_dims();
+  // 'data' and 'indices' must have rank strictly greater than zero.
+  if (dataRank < 1)
+    return onnx_mlir::Diagnostic::emitOperandHasUnexpectedRankError(
+        *this->getOperation(), data, dataRank, "> 0");
+  if (indicesRank < 1)
+    return onnx_mlir::Diagnostic::emitOperandHasUnexpectedRankError(
+        *this->getOperation(), indices, indicesRank, "> 0");
+  ArrayRef<int64_t> dataShape = dataType.getShape();
+  ArrayRef<int64_t> indicesShape = indicesType.getShape();
+  int64_t indicesLastDim = indicesShape[indicesRank - 1];
+  // b must be in the range [0, min(rank(data), rank(indices)) - 1].
+  int64_t minDataAndIndicesRank = std::min(dataRank, indicesRank);
+  if (b < 0 || b >= minDataAndIndicesRank)
+    return onnx_mlir::Diagnostic::emitAttributeOutOfRangeError(
+        *this->getOperation(), "batch_dims", b,
+        onnx_mlir::Diagnostic::Range<int64_t>(0, minDataAndIndicesRank - 1));
+  // The first b dimensions of the shape of 'indices' and 'data' must be equal.
+  // Dynamic (negative) dimensions cannot be compared and are skipped.
+  for (int64_t i = 0; i < b; ++i) {
+    int64_t dataDim = dataShape[i];
+    int64_t indicesDim = indicesShape[i];
+    if (indicesDim < 0 || dataDim < 0)
+      continue;
+    if (indicesDim != dataDim)
+      return onnx_mlir::Diagnostic::emitDimensionHasUnexpectedValueError(
+          *this->getOperation(), indices, i, indicesShape[i],
+          std::to_string(dataShape[i]));
+  }
+  // Let r = rank(data), indices.shape[-1] must be in the range [1, r-b].
+  if (indicesLastDim == 0)
+    return onnx_mlir::Diagnostic::emitDimensionHasUnexpectedValueError(
+        *this->getOperation(), indices, indicesRank - 1, indicesLastDim,
+        ">= 1");
+  if (indicesLastDim > dataRank - b)
+    return onnx_mlir::Diagnostic::emitDimensionHasUnexpectedValueError(
+        *this->getOperation(), indices, indicesRank - 1, indicesLastDim,
+        "<= " + std::to_string(dataRank - b));
+  // All values in 'indices' are expected to satisfy the inequality:
+  //   -data.shape[b + i] <= indices[..., i] <= data.shape[b + i] - 1.
+  // The check needs a static indices.shape[-1] to map a flat element position
+  // to the 'data' axis it addresses, and constant index values.
+  if (indicesLastDim > 0) {
+    if (DenseElementsAttr valueAttribute =
+            getDenseElementAttributeFromONNXValue(indices)) {
+      // Position of the current value in the flattened 'indices' tensor. The
+      // last dimension varies fastest, so (flatIndex % indicesLastDim) is the
+      // position of the value within its index tuple.
+      int64_t flatIndex = 0;
+      for (IntegerAttr value : valueAttribute.getValues<IntegerAttr>()) {
+        // Always in range: indicesLastDim <= dataRank - b was checked above.
+        int64_t gatherAxis = b + (flatIndex % indicesLastDim);
+        int64_t dataDimAtAxis = dataShape[gatherAxis];
+        int64_t index = value.getInt();
+        // A dynamic 'data' dimension cannot be used to bound the index.
+        if (dataDimAtAxis >= 0 &&
+            (index < -dataDimAtAxis || index > dataDimAtAxis - 1))
+          return onnx_mlir::Diagnostic::emitAttributeOutOfRangeError(
+              *this->getOperation(),
+              "indices[" + std::to_string(flatIndex) + "]", index,
+              onnx_mlir::Diagnostic::Range<int64_t>(
+                  -dataDimAtAxis, dataDimAtAxis - 1));
+        ++flatIndex;
+      }
+    }
+  }
+  return success();
+/// Infers the output shape of GatherND through ONNXGatherNDOpShapeHelper.
+/// Returns success() without inferring anything while an operand has no
+/// shape/rank yet or while indices.shape[-1] is not statically known.
+LogicalResult ONNXGatherNDOp::inferShapes(
+    std::function<void(mlir::Region &)> doShapeInference) {
+  // Cannot infer the shape of the output if the inputs shape is not yet known.
+  if (llvm::any_of(
+          this->getOperands(), [](Value op) { return !hasShapeAndRank(op); }))
+    return success();
+  // The output rank is given by:
+  //   rank(output) = rank(indices) + rank(data) - indices_shape[-1] - 1 - b.
+  // Therefore 'indices.shape[-1]' must be known in order to compute the output
+  // shape.
+  ArrayRef<int64_t> indicesShape =
+      indices().getType().cast<ShapedType>().getShape();
+  // A rank-0 'indices' has no last dimension to read; it is rejected by the
+  // verifier, so simply do not attempt any inference here.
+  if (indicesShape.empty())
+    return success();
+  if (indicesShape.back() < 0)
+    return success(); // cannot infer the output shape yet.
+  auto elementType = data().getType().cast<ShapedType>().getElementType();
+  return shapeHelperInferShapes<ONNXGatherNDOpShapeHelper, ONNXGatherNDOp,
+      ONNXGatherNDOpAdaptor>(*this, elementType);
 // ConstantOfShape
@@ -3892,10 +4001,6 @@ LogicalResult ONNXFloorOp::inferShapes(
   return success();
-LogicalResult ONNXGatherNDOp::inferShapes(
-    std::function<void(mlir::Region &)> doShapeInference) {
-  return emitError(NOT_IMPLEMENTED_MESSAGE);
 LogicalResult ONNXGreaterOp::inferShapes(
     std::function<void(mlir::Region &)> doShapeInference) {
   Builder b(getContext());
diff --git a/src/Dialect/ONNX/ONNXOps.td.inc b/src/Dialect/ONNX/ONNXOps.td.inc
index d9c3df2e83..b25f1135ea 100644
--- a/src/Dialect/ONNX/ONNXOps.td.inc
+++ b/src/Dialect/ONNX/ONNXOps.td.inc
@@ -1910,6 +1910,7 @@ def ONNXGatherNDOp:ONNX_Op<"GatherND",
       return {20};
+  let hasVerifier = 1;
 def ONNXGemmOp:ONNX_Op<"Gemm",
diff --git a/src/Dialect/ONNX/ShapeInference/GatherND.cpp b/src/Dialect/ONNX/ShapeInference/GatherND.cpp
new file mode 100644
index 0000000000..cffb7fe7fc
--- /dev/null
+++ b/src/Dialect/ONNX/ShapeInference/GatherND.cpp
@@ -0,0 +1,70 @@
+ * SPDX-License-Identifier: Apache-2.0
+ */
+//===---------- GatherND.cpp - Shape Inference for GatherND Op ------------===//
+// This file implements shape inference for the ONNX GatherND Operator.
+#include "src/Dialect/ONNX/ShapeInference/ONNXShapeHelper.hpp"
+#include <algorithm>
+using namespace mlir;
+namespace onnx_mlir {
+LogicalResult ONNXGatherNDOpShapeHelper::computeShape(
+    ONNXGatherNDOpAdaptor operandAdaptor) {
+  // Collect the dimensions of both operands as index expressions.
+  Value dataVal = operandAdaptor.data();
+  Value indicesVal = operandAdaptor.indices();
+  MemRefBoundsIndexCapture dataCapture(dataVal);
+  MemRefBoundsIndexCapture indicesCapture(indicesVal);
+  DimsExpr dataDims, indicesDims;
+  dataCapture.getDimList(dataDims);
+  indicesCapture.getDimList(indicesDims);
+  const int64_t dataRank = dataDims.size();
+  const int64_t indicesRank = indicesDims.size();
+  const int64_t b = op->batch_dims();
+  assert(indicesVal.getType().isa<ShapedType>() && "Expecting a shaped type");
+  ArrayRef<int64_t> indicesShape =
+      indicesVal.getType().cast<ShapedType>().getShape();
+  const int64_t indicesLastDim = indicesShape[indicesRank - 1];
+  const int64_t outputRank = dataRank + indicesRank - indicesLastDim - 1 - b;
+  // Ensure the operator constraints are satisfied.
+  assert(dataRank >= 1 && "dataRank should be >= 1");
+  assert(indicesRank >= 1 && "indicesRank should be >= 1");
+  assert(b >= 0 && "batch_dim should not be negative");
+  assert(b < std::min(dataRank, indicesRank) &&
+         "batch_dims must be smaller than the min(dataRank, indicesRank)");
+  assert((indicesLastDim >= 1 && indicesLastDim <= dataRank - b) &&
+         "indices.shape[-1] must be in the range [1, dataRank - b]");
+  auto &outputDims = dimsForOutput();
+  // output.shape = batchDims + list(indices.shape)[b:-1], where batchDims are
+  // the first 'b' dimensions of the shape of the 'indices' tensor.
+  for (int64_t batchAxis = 0; batchAxis < b; ++batchAxis)
+    outputDims.emplace_back(indicesDims[batchAxis]);
+  for (int64_t axis = b; axis < indicesRank - 1; ++axis)
+    outputDims.emplace_back(indicesDims[axis]);
+  // When indices.shape[-1] < data_rank - b,
+  //   output_shape += list(data.shape)[batch_dims + indices.shape[-1]:]
+  if (indicesLastDim < dataRank - b)
+    for (int64_t axis = b + indicesLastDim; axis < dataRank; ++axis)
+      outputDims.emplace_back(dataDims[axis]);
+  assert((int64_t)dimsForOutput().size() == outputRank &&
+         "Incorrect shape computation");
+  return success();
+} // namespace onnx_mlir
diff --git a/src/Dialect/ONNX/ShapeInference/ONNXShapeHelper.cpp b/src/Dialect/ONNX/ShapeInference/ONNXShapeHelper.cpp
index 46ba4c4c4a..fec182616e 100644
--- a/src/Dialect/ONNX/ShapeInference/ONNXShapeHelper.cpp
+++ b/src/Dialect/ONNX/ShapeInference/ONNXShapeHelper.cpp
@@ -417,6 +417,7 @@ template struct ONNXOpShapeHelper<ONNXExpandOp>;
 template struct ONNXOpShapeHelper<ONNXFlattenOp>;
 template struct ONNXOpShapeHelper<ONNXGatherOp>;
 template struct ONNXOpShapeHelper<ONNXGatherElementsOp>;
+template struct ONNXOpShapeHelper<ONNXGatherNDOp>;
 template struct ONNXOpShapeHelper<ONNXGemmOp>;
 template struct ONNXOpShapeHelper<ONNXMatMulOp>;
 template struct ONNXOpShapeHelper<ONNXMaxPoolSingleOutOp>;
diff --git a/src/Dialect/ONNX/ShapeInference/ONNXShapeHelper.hpp b/src/Dialect/ONNX/ShapeInference/ONNXShapeHelper.hpp
index 5cfa392762..3c2f6e8160 100644
--- a/src/Dialect/ONNX/ShapeInference/ONNXShapeHelper.hpp
+++ b/src/Dialect/ONNX/ShapeInference/ONNXShapeHelper.hpp
@@ -230,6 +230,7 @@ DECLARE_SHAPE_HELPER(ONNXDepthToSpaceOp)
diff --git a/src/Runtime/OMTensor.inc b/src/Runtime/OMTensor.inc
index a31d87284e..6658d7a54f 100644
--- a/src/Runtime/OMTensor.inc
+++ b/src/Runtime/OMTensor.inc
@@ -14,6 +14,7 @@
 #ifdef __cplusplus
+#include <array>
 #include <cassert>
 #include <complex>
 #include <map>
@@ -161,10 +162,10 @@ static inline void printElement(
     printf("%lld", (long long)((int64_t *)dataPtr)[elemOffset]);
-    printf("%g", ((float *)dataPtr)[elemOffset]);
+    printf("%f", ((float *)dataPtr)[elemOffset]);
-    printf("%g", ((double *)dataPtr)[elemOffset]);
+    printf("%f", ((double *)dataPtr)[elemOffset]);
     printf("%s", ((const char **)dataPtr)[elemOffset]);
@@ -423,63 +424,78 @@ void omTensorPrint(const char *msg, const OMTensor *tensor) {
   printf("\trank = %lld\n", (long long)rank);
   printf("\tdataType = %s\n", getDataTypeName(dataType));
   printf("\tnumElems = %lld\n", (long long)omTensorGetNumElems(tensor));
+  printf("\tshape: ");
+  for (int64_t i = 0; i < rank; i++)
+    printf("[%lld]", (long long)shape[i]);
+  printf("\n");
   printf("\tstrides: ");
   for (int64_t i = 0; i < rank; i++)
     printf("[%lld]", (long long)strides[i]);
-  printf("\tdata: ([");
+#define LOOP_1(INDEX, IV, UB)                                                  \
+  printf("[");                                                                 \
+  for (int64_t IV = 0; IV < UB; ++IV) {                                        \
+    if (IV)                                                                    \
+      printf(", ");                                                            \
+    indexes[INDEX] = IV;                                                       \
+    int64_t elemOffset = computeElemOffset(tensor->_strides, indexes, rank);   \
+    printElement(dataPtr, elemOffset, dataType);                               \
+  }                                                                            \
+  printf("]");
+#define LOOP_2(INDEX, IV, UB, ...)                                             \
+  printf("[");                                                                 \
+  for (int64_t IV = 0; IV < UB; ++IV) {                                        \
+    if (IV)                                                                    \
+      printf(", ");                                                            \
+    indexes[INDEX] = IV;                                                       \
+    LOOP_1(INDEX + 1, __VA_ARGS__)                                             \
+  }                                                                            \
+  printf("]");
+#define LOOP_3(INDEX, IV, UB, ...)                                             \
+  printf("[");                                                                 \
+  for (int64_t IV = 0; IV < UB; ++IV) {                                        \
+    if (IV)                                                                    \
+      printf(", ");                                                            \
+    indexes[INDEX] = IV;                                                       \
+    LOOP_2(INDEX + 1, __VA_ARGS__)                                             \
+  }                                                                            \
+  printf("]");
+#define LOOP_4(INDEX, IV, UB, ...)                                             \
+  printf("[");                                                                 \
+  for (int64_t IV = 0; IV < UB; ++IV) {                                        \
+    if (IV)                                                                    \
+      printf(", ");                                                            \
+    indexes[INDEX] = IV;                                                       \
+    LOOP_3(INDEX + 1, __VA_ARGS__)                                             \
+  }                                                                            \
+  printf("]");
+  printf("\tdata: (");
   switch (rank) {
-  case 1:
-    for (int64_t i = 0; i < shape[0]; ++i) {
-      if (i)
-        printf(", ");
-      int64_t indexes[] = {i};
-      int64_t elemOffset = computeElemOffset(tensor->_strides, indexes, rank);
-      printElement(dataPtr, elemOffset, dataType);
-    }
-    break;
-  case 2:
-    for (int64_t i = 0; i < shape[0]; ++i) {
-      if (i)
-        printf(", ");
-      printf("[");
-      for (int64_t j = 0; j < shape[1]; ++j) {
-        if (j)
-          printf(", ");
-        int64_t indexes[] = {i, j};
-        int64_t elemOffset = computeElemOffset(tensor->_strides, indexes, rank);
-        printElement(dataPtr, elemOffset, dataType);
-      }
-      printf("]");
-    }
-    break;
-  case 3:
-    for (int64_t i = 0; i < shape[0]; ++i) {
-      if (i)
-        printf(", ");
-      printf("[");
-      for (int64_t j = 0; j < shape[1]; ++j) {
-        if (j)
-          printf(", ");
-        printf("[");
-        for (int64_t k = 0; k < shape[2]; ++k) {
-          if (k)
-            printf(", ");
-          int64_t indexes[] = {i, j, k};
-          int64_t elemOffset =
-              computeElemOffset(tensor->_strides, indexes, rank);
-          printElement(dataPtr, elemOffset, dataType);
-        }
-        printf("]");
-      }
-      printf("]");
-    }
-    break;
+  case 1: {
+    int64_t indexes[1];
+    LOOP_1(0, i, shape[0])
+  } break;
+  case 2: {
+    int64_t indexes[2];
+    LOOP_2(0, i, shape[0], j, shape[1])
+  } break;
+  case 3: {
+    int64_t indexes[3];
+    LOOP_3(0, i, shape[0], j, shape[1], k, shape[2])
+  } break;
+  case 4: {
+    int64_t indexes[4];
+    LOOP_4(0, i, shape[0], j, shape[1], k, shape[2], l, shape[3])
+  } break;
     assert(false && "not implemented");
-  printf("])\n");
+  printf(")\n");
 #ifdef __cplusplus
@@ -657,8 +673,8 @@ inline bool omTensorAreTwoOmtsClose(
       eqAllclose.begin(), eqAllclose.end(), [&](T eq) { return eq >= 0; });
   if (!satisfied) {
-    // Figure out where and what went wrong, this can be slow; but hopefully we
-    // don't need this often.
+    // Figure out where and what went wrong, this can be slow; but hopefully
+    // we don't need this often.
     for (const auto &idx : omTensorComputeIndexSet(a)) {
       T aElem = omTensorGetElem<T>(a, idx);
       T bElem = omTensorGetElem<T>(b, idx);
diff --git a/src/Support/Diagnostic.hpp b/src/Support/Diagnostic.hpp
index 3b44e6a5da..a044b452fa 100644
--- a/src/Support/Diagnostic.hpp
+++ b/src/Support/Diagnostic.hpp
@@ -36,7 +36,7 @@ class Diagnostic {
     Range(T min, T max) : min(min), max(max) {
-      assert(min < max && "Illegal range");
+      assert(min <= max && "Illegal range");
diff --git a/test/backend/inference_backend.py b/test/backend/inference_backend.py
index 5bef3c7757..d0b42d94b6 100644
--- a/test/backend/inference_backend.py
+++ b/test/backend/inference_backend.py
@@ -298,6 +298,11 @@ def get_test_models():
         "test_gather_elements_1_cpu": {STATIC_SHAPE:{}, DYNAMIC_SHAPE:{-1:{-1}}, CONSTANT_INPUT:{-1}},
         "test_gather_elements_negative_indices_cpu": {STATIC_SHAPE:{}, DYNAMIC_SHAPE:{-1:{-1}}, CONSTANT_INPUT:{-1}},
+        # GatherND
+        "test_gathernd_example_int32_cpu": {STATIC_SHAPE:{}, CONSTANT_INPUT:{-1}},
+        "test_gathernd_example_float32_cpu": {STATIC_SHAPE:{}, CONSTANT_INPUT:{-1}},
+        "test_gathernd_example_int32_batch_dim1_cpu": {STATIC_SHAPE:{}, CONSTANT_INPUT:{-1}},
         # Gemm
         "test_gemm_all_attributes_cpu": {STATIC_SHAPE:{}, DYNAMIC_SHAPE:{-1:{-1}}, CONSTANT_INPUT:{-1}},
         "test_gemm_alpha_cpu": {STATIC_SHAPE:{}, DYNAMIC_SHAPE:{-1:{-1}}, CONSTANT_INPUT:{-1}},
diff --git a/test/mlir/onnx/invalid.mlir b/test/mlir/onnx/invalid.mlir
index 04f9ae252b..e05e6013a6 100644
--- a/test/mlir/onnx/invalid.mlir
+++ b/test/mlir/onnx/invalid.mlir
@@ -164,6 +164,7 @@ func @test_hardmax_verifier_1(%arg0: tensor<2x2xf32>) -> tensor<*xf32> {
 // -----
+// COM: Rank of 'data' has to be >=1
 func @test_gather_elements_verifier_1(%arg0 : tensor<f32>, %arg1 : tensor<5xi64>) -> tensor<*xf32> {
   // expected-error @+1 {{onnx.GatherElements: operand '<block argument> of type 'tensor<f32>' at index: 0' has rank 0, rank should be > 0}}  
   %1 = "onnx.GatherElements"(%arg0, %arg1) {axis = 4 : si64} : (tensor<f32>, tensor<5xi64>)  -> tensor<*xf32>
@@ -172,6 +173,7 @@ func @test_gather_elements_verifier_1(%arg0 : tensor<f32>, %arg1 : tensor<5xi64>
 // -----
+// COM: Rank of 'indices' must be equal to the rank of 'data'.
 func @test_gather_elements_verifier_2(%arg0 : tensor<5xf32>, %arg1 : tensor<5x3xi64>) -> tensor<*xf32> {
   // expected-error @+1 {{onnx.GatherElements: operand '<block argument> of type 'tensor<5x3xi64>' at index: 1' has rank 2, rank should be 1}}
   %1 = "onnx.GatherElements"(%arg0, %arg1) {axis = 4 : si64} : (tensor<5xf32>, tensor<5x3xi64>)  -> tensor<*xf32>
@@ -180,6 +182,7 @@ func @test_gather_elements_verifier_2(%arg0 : tensor<5xf32>, %arg1 : tensor<5x3x
 // -----
+// COM: 'axis' valid range is [-r, r-1], where r = rank(data).
 func @test_gather_elements_verifier_3(%arg0 : tensor<5x5x1x32xf32>, %arg1 : tensor<5x5x1x32xi64>) -> tensor<*xf32> {
   // expected-error @+1 {{onnx.GatherElements: 'axis' value is 4, accepted range is [-4, 3]}}
   %1 = "onnx.GatherElements"(%arg0, %arg1) {axis = 4 : si64} : (tensor<5x5x1x32xf32>, tensor<5x5x1x32xi64>)  -> tensor<*xf32>
@@ -188,6 +191,7 @@ func @test_gather_elements_verifier_3(%arg0 : tensor<5x5x1x32xf32>, %arg1 : tens
 // -----
+// COM: All index values in 'indices' are expected to be within bounds [-s, s-1] along axis of size s.
 func @test_gather_elements_verifier_4(%arg0 : tensor<3xf32>, %arg1 : tensor<3xf32>) -> tensor<*xf32> {
   // expected-error @+2 {{onnx.GatherElements: 'indices' value is 3, accepted range is [-3, 2]}}
   %indices = "onnx.Constant"() {value = dense<[3]> : tensor<1xi64>} : () -> tensor<1xi64>
@@ -197,6 +201,58 @@ func @test_gather_elements_verifier_4(%arg0 : tensor<3xf32>, %arg1 : tensor<3xf3
 // -----
+// COM: Rank of 'data' has to be >=1
+func @test_gatherND_verifier_1(%arg0 : tensor<f32>, %arg1 : tensor<5xi64>) -> tensor<*xf32> {
+  // expected-error @+1 {{onnx.GatherND: operand '<block argument> of type 'tensor<f32>' at index: 0' has rank 0, rank should be > 0}}  
+  %1 = "onnx.GatherND"(%arg0, %arg1) : (tensor<f32>, tensor<5xi64>)  -> tensor<*xf32>
+// -----
+// COM: Rank of 'indices' has to be >=1
+func @test_gatherND_verifier_2(%arg0 : tensor<2xf32>, %arg1 : tensor<i64>) -> tensor<*xf32> {
+  // expected-error @+1 {{onnx.GatherND: operand '<block argument> of type 'tensor<i64>' at index: 1' has rank 0, rank should be > 0}}  
+  %1 = "onnx.GatherND"(%arg0, %arg1) : (tensor<2xf32>, tensor<i64>)  -> tensor<*xf32>
+// -----
+// COM: The value batch_dims must be smaller than the minimum of rank(data) and rank(indices).
+func @test_gatherND_verifier_3(%arg0 : tensor<1x2x3xf32>, %arg1 : tensor<2x2x2x2xi64>) -> tensor<*xf32> {
+  // expected-error @+1 {{onnx.GatherND: 'batch_dims' value is 3, accepted range is [0, 2]}}
+  %1 = "onnx.GatherND"(%arg0, %arg1) {batch_dims = 3 : si64}: (tensor<1x2x3xf32>, tensor<2x2x2x2xi64>)  -> tensor<*xf32>
+// -----
+// COM: The first 'batchDims' dimensions of the shape of the 'indices' and 'data' tensors must be equal.
+func @test_gatherND_verifier_4(%arg0 : tensor<2x2x3x4xf32>, %arg1 : tensor<2x3x2xi64>) -> tensor<*xf32> {
+  // expected-error @+1 {{onnx.GatherND: operand '<block argument> of type 'tensor<2x3x2xi64>' at index: 1' has dimension at index 1 with value 3, value should be 2}}
+  %1 = "onnx.GatherND"(%arg0, %arg1) {batch_dims = 2 : si64} : (tensor<2x2x3x4xf32>, tensor<2x3x2xi64>)  -> tensor<*xf32>
+  "std.return"(%1) : (tensor<*xf32>) -> ()
+// -----
+// COM: The last dimension of the 'indices' shape must be a value in the range [1, rank(data)-batch_dims].
+func @test_gatherND_verifier_5(%arg0 : tensor<1x2x3x4xf32>, %arg1 : tensor<1x4xi64>) -> tensor<*xf32> {
+  // expected-error @+1 {{onnx.GatherND: operand '<block argument> of type 'tensor<1x4xi64>' at index: 1' has dimension at index 1 with value 4, value should be <= 3}}
+  %1 = "onnx.GatherND"(%arg0, %arg1) {batch_dims = 1 : si64} : (tensor<1x2x3x4xf32>, tensor<1x4xi64>)  -> tensor<*xf32>
+  "std.return"(%1) : (tensor<*xf32>) -> ()
+// -----
+// COM: All values in 'indices' are expected to satisfy the inequality:
+// COM:   -data.shape[i] <= indices[...,i] <= (data.shape[i]-1).
+func @test_gatherND_verifier_6(%arg0 : tensor<3x4x4x4xf32>) -> tensor<*xf32> {
+  // expected-error @+2 {{onnx.GatherND: 'indices[0]' value is 3, accepted range is [-3, 2]}}
+  %indices = "onnx.Constant"() {value = dense<[3,2,2]> : tensor<3xi64>} : () -> tensor<3x3x2xi64>
+  %1 = "onnx.GatherND"(%arg0, %indices) : (tensor<3x4x4x4xf32>, tensor<3x3x2xi64>)  -> tensor<*xf32>
+// -----
 func @test_onehotencoder_verifier_1(%arg0: tensor<2x2xf32>) -> tensor<*xf32> {
   // expected-error @+1 {{'onnx.OneHotEncoder' op input is a tensor of float, int32, or double, but no cats_int64s attribute}}
   %1 = "onnx.OneHotEncoder"(%arg0) { cats_string = ["a","b","c"]} : (tensor<2x2xf32>) -> tensor<*xf32>
diff --git a/test/mlir/onnx/onnx_lowering.mlir b/test/mlir/onnx/onnx_lowering.mlir
index 6e3a2b65a9..ca2e30669f 100644
--- a/test/mlir/onnx/onnx_lowering.mlir
+++ b/test/mlir/onnx/onnx_lowering.mlir
@@ -2532,6 +2532,82 @@ func @test_resize2(%arg0 : tensor<3x4xf32>) -> tensor<*xf32> {
 // CHECK:           return [[RES]] : memref<2xi64>
+// COM: Test GatherND with indices_shape[-1] == rank(data) - batch_dims
+func @test_gather_nd_1(%arg0 : tensor<2x2xf32>, %arg1 : tensor<2x2xi64>) -> tensor<2xf32> {
+  %0 = "onnx.GatherND"(%arg0, %arg1) {batch_dims = 0 : si64} : (tensor<2x2xf32>, tensor<2x2xi64>) -> tensor<2xf32>
+  "std.return"(%0) : (tensor<2xf32>) -> ()
+// CHECK-LABEL:  @test_gather_nd_1
+// CHECK-SAME:   ([[PARAM_0:%.+]]: memref<2x2xf32>, [[PARAM_1:%.+]]: memref<2x2xi64>) -> memref<2xf32> {
+// CHECK:           [[RESHAPED_INDICES:%.+]] = memref.reinterpret_cast %arg1 to offset: [0], sizes: [1, 2, 2], strides: [4, 2, 1] : memref<2x2xi64> to memref<1x2x2xi64>
+// CHECK:           [[RESHAPED_DATA:%.+]] = memref.reinterpret_cast %arg0 to offset: [0], sizes: [1, 2, 2], strides: [4, 2, 1] : memref<2x2xf32> to memref<1x2x2xf32>
+// CHECK-DAG:       [[RES_BUFFER:%.+]] = memref.alloc() : memref<2xf32>
+// CHECK-DAG:       [[RES_BUFFER_INDEX:%.+]] = memref.alloca() : memref<index>
+// CHECK-DAG:       [[CST_0_0:%.+]] = arith.constant 0 : index
+// CHECK-DAG:       [[CST_1_0:%.+]] = arith.constant 1 : index
+// CHECK:           krnl.store [[CST_0_0]], [[RES_BUFFER_INDEX]][] : memref<index>
+// CHECK:           [[LOOP:%.+]]:2 = krnl.define_loops 2
+// CHECK:           krnl.iterate([[LOOP]]#0, [[LOOP]]#1) with ([[LOOP]]#0 -> [[I_0:%.+]] = 0 to 1, [[LOOP]]#1 -> [[I_1:%.+]] = 0 to 2){
+// CHECK-DAG:         [[IV:%.+]]:2 = krnl.get_induction_var_value([[LOOP]]#0, [[LOOP]]#1) : (!krnl.loop, !krnl.loop) -> (index, index)
+// CHECK:             [[CST_0_1:%.+]] = arith.constant 0 : index
+// CHECK:             [[LOAD_INDEX_1:%.+]] = krnl.load [[RESHAPED_INDICES]][[[IV]]#0, [[IV]]#1, [[CST_0_1]]] : memref<1x2x2xi64>
+// CHECK-DAG:         [[INDEX_1:%.+]] = arith.index_cast [[LOAD_INDEX_1]] : i64 to index
+// CHECK-DAG:         [[CST_1_1:%.+]] = arith.constant 1 : index
+// CHECK:             [[LOAD_INDEX_2:%.+]] = krnl.load [[RESHAPED_INDICES]][[[IV]]#0, [[IV]]#1, [[CST_1_1]]] : memref<1x2x2xi64>
+// CHECK:             [[INDEX_2:%.+]] = arith.index_cast [[LOAD_INDEX_2]] : i64 to index
+// CHECK-DAG:         [[DATA_VAL:%.+]] = krnl.load [[RESHAPED_DATA]][[[IV]]#0, [[INDEX_1]], [[INDEX_2]]] : memref<1x2x2xf32>
+// CHECK-DAG:         [[RES_BUFFER_INDEX_VAL:%.+]] = krnl.load [[RES_BUFFER_INDEX]][] : memref<index>
+// CHECK:             krnl.store [[DATA_VAL]], [[RES_BUFFER]][[[RES_BUFFER_INDEX_VAL]]] : memref<2xf32>
+// CHECK:             [[PLUS_ONE:%.+]] = arith.addi [[RES_BUFFER_INDEX_VAL]], [[CST_1_0]] : index
+// CHECK:             krnl.store [[PLUS_ONE]], [[RES_BUFFER_INDEX]][] : memref<index>
+// CHECK:           }
+// CHECK:          [[RES:%.+]] = memref.reinterpret_cast [[RES_BUFFER]] to offset: [0], sizes: [2], strides: [1] : memref<2xf32> to memref<2xf32> 
+// CHECK:           return [[RES]] : memref<2xf32>
+// COM: Test GatherND with indices_shape[-1] < rank(data) - batch_dims
+func @test_gather_nd_2(%arg0 : tensor<2x2x2xf32>, %arg1 : tensor<2x1x2xi64>) -> tensor<2x1x2xf32> {
+  %0 = "onnx.GatherND"(%arg0, %arg1) {batch_dims = 0 : si64} : (tensor<2x2x2xf32>, tensor<2x1x2xi64>) -> tensor<2x1x2xf32>
+  "std.return"(%0) : (tensor<2x1x2xf32>) -> ()
+// CHECK-LABEL:  func @test_gather_nd_2
+// CHECK-SAME:   ([[PARAM_0:%.+]]: memref<2x2x2xf32>, [[PARAM_1:%.+]]: memref<2x1x2xi64>) -> memref<2x1x2xf32> {
+// CHECK-DAG:       [[RESHAPED_INDICES:%.+]] = memref.reinterpret_cast [[PARAM_1]] to offset: [0], sizes: [1, 2, 2], strides: [4, 2, 1] : memref<2x1x2xi64> to memref<1x2x2xi64>
+// CHECK-DAG:       [[RESHAPED_DATA:%.+]] = memref.reinterpret_cast [[PARAM_0]] to offset: [0], sizes: [1, 2, 2, 2], strides: [8, 4, 2, 1] : memref<2x2x2xf32> to memref<1x2x2x2xf32>
+// CHECK-DAG:       [[RES_BUFFER:%.+]] = memref.alloc() : memref<4xf32>
+// CHECK:           [[CST_0_0:%.+]] = arith.constant 0 : index
+// CHECK:           [[CST_1_0:%.+]] = arith.constant 1 : index
+// CHECK:           [[RES_INDEX_BUFFER:%.+]] = memref.alloca() : memref<index>
+// CHECK:           krnl.store [[CST_0_0]], [[RES_INDEX_BUFFER]][] : memref<index>
+// CHECK:           [[LOOP_0:%.+]]:2 = krnl.define_loops 2
+// CHECK:           krnl.iterate([[LOOP_0]]#0, [[LOOP_0]]#1) with ([[LOOP_0]]#0 -> [[I_0_:%.+]] = 0 to 1, [[LOOP_0]]#1 -> [[I_1_:%.+]] = 0 to 2){
+// CHECK-DAG:         [[IV:%.+]]:2 = krnl.get_induction_var_value([[LOOP_0]]#0, [[LOOP_0]]#1) : (!krnl.loop, !krnl.loop) -> (index, index)
+// CHECK-DAG:         [[CST_0_1:%.+]] = arith.constant 0 : index
+// CHECK:             [[LOAD_INDEX_1:%.+]] = krnl.load [[RESHAPED_INDICES]]{{.}}[[IV]]#0, [[IV]]#1, [[CST_0_1]]{{.}} : memref<1x2x2xi64>
+// CHECK-DAG:         [[INDEX_1:%.+]] = arith.index_cast [[LOAD_INDEX_1]] : i64 to index
+// CHECK-DAG:         [[CST_1_1:%.+]] = arith.constant 1 : index
+// CHECK:             [[LOAD_INDEX_2:%.+]] = krnl.load [[RESHAPED_INDICES]]{{.}}[[IV]]#0, [[IV]]#1, [[CST_1_1]]{{.}} : memref<1x2x2xi64>
+// CHECK-DAG:         [[INDEX_2:%.+]] = arith.index_cast [[LOAD_INDEX_2]] : i64 to index
+// CHECK-DAG:         [[CST_0_2:%.+]] = arith.constant 0 : index
+// CHECK-NOT: separator of consecutive DAGs
+// CHECK-DAG:         [[DATA_1:%.+]] = krnl.load [[RESHAPED_DATA]]{{.}}[[IV]]#0, [[INDEX_1]], [[INDEX_2]], [[CST_0_2]]{{.}} : memref<1x2x2x2xf32>
+// CHECK-DAG:         [[RES_INDEX_1:%.+]] = krnl.load [[RES_INDEX_BUFFER]][] : memref<index>
+// CHECK:             krnl.store [[DATA_1]], [[RES_BUFFER]]{{.}}[[RES_INDEX_1]]{{.}} : memref<4xf32>
+// CHECK:             [[PLUS_ONE:%.+]] = arith.addi [[RES_INDEX_1]], [[CST_1_0]] : index
+// CHECK:             krnl.store [[PLUS_ONE]], [[RES_INDEX_BUFFER]][] : memref<index>
+// CHECK:             [[CST_1_2:%.+]] = arith.constant 1 : index
+// CHECK-DAG:         [[DATA_2:%.+]] = krnl.load [[RESHAPED_DATA]]{{.}}[[IV]]#0, [[INDEX_1]], [[INDEX_2]], [[CST_1_2]]{{.}} : memref<1x2x2x2xf32>
+// CHECK-DAG:         [[RES_INDEX_2:%.+]] = krnl.load [[RES_INDEX_BUFFER]][] : memref<index>
+// CHECK:             krnl.store [[DATA_2]], [[RES_BUFFER]]{{.}}[[RES_INDEX_2]]{{.}} : memref<4xf32>
+// CHECK:             [[PLUS_ONE_1:%.+]] = arith.addi [[RES_INDEX_2]], [[CST_1_0]] : index
+// CHECK:             krnl.store [[PLUS_ONE_1]], [[RES_INDEX_BUFFER]][] : memref<index>
+// CHECK:           }
+// CHECK:           [[RES:%.+]] = memref.reinterpret_cast [[RES_BUFFER]] to offset: [0], sizes: [2, 1, 2], strides: [2, 2, 1] : memref<4xf32> to memref<2x1x2xf32>
+// CHECK:           return [[RES]] : memref<2x1x2xf32>
   func @test_reversesequence_1(%arg0: tensor<10x?xf32>, %arg1: tensor<10xi64>) -> tensor<*xf32> {
diff --git a/test/mlir/onnx/onnx_shape_inference.mlir b/test/mlir/onnx/onnx_shape_inference.mlir
index eecd97512b..e0f9cc9a70 100644
--- a/test/mlir/onnx/onnx_shape_inference.mlir
+++ b/test/mlir/onnx/onnx_shape_inference.mlir
@@ -1568,6 +1568,62 @@ func @test_gather_negative_axis(%arg0 : tensor<3x3xf32>, %arg1 : tensor<1x2xi64>
   // CHECK: return [[RES]] : tensor<3x1x2xf32>
+// -----
+func @test_gather_nd_1(%arg0 : tensor<2x2xf32>, %arg1 : tensor<2x2xi64>) -> tensor<*xf32> {
+  %0 = "onnx.GatherND"(%arg0, %arg1) {batch_dims = 0 : si64} : (tensor<2x2xf32>, tensor<2x2xi64>) -> tensor<*xf32>
+  "std.return"(%0) : (tensor<*xf32>) -> ()
+  // CHECK-LABEL: test_gather_nd_1
+  // CHECK: [[RES:%.+]] = "onnx.GatherND"(%arg0, %arg1) {batch_dims = 0 : si64} : (tensor<2x2xf32>, tensor<2x2xi64>) -> tensor<2xf32>
+  // CHECK: return [[RES]] : tensor<2xf32>
+// -----
+func @test_gather_nd_2(%arg0 : tensor<2x2xf32>, %arg1 : tensor<2x1xi64>) -> tensor<*xf32> {
+  %0 = "onnx.GatherND"(%arg0, %arg1) {batch_dims = 0 : si64} : (tensor<2x2xf32>, tensor<2x1xi64>) -> tensor<*xf32>
+  "std.return"(%0) : (tensor<*xf32>) -> ()
+  // CHECK-LABEL: test_gather_nd_2
+  // CHECK: [[RES:%.+]] = "onnx.GatherND"(%arg0, %arg1) {batch_dims = 0 : si64} : (tensor<2x2xf32>, tensor<2x1xi64>) -> tensor<2x2xf32>
+  // CHECK: return [[RES]] : tensor<2x2xf32>
+// -----
+func @test_gather_nd_3(%arg0 : tensor<2x2x2xf32>, %arg1 : tensor<2x2xi64>) -> tensor<*xf32> {
+  %0 = "onnx.GatherND"(%arg0, %arg1) {batch_dims = 0 : si64} : (tensor<2x2x2xf32>, tensor<2x2xi64>) -> tensor<*xf32>
+  "std.return"(%0) : (tensor<*xf32>) -> ()
+  // CHECK-LABEL: test_gather_nd_3
+  // CHECK: [[RES:%.+]] = "onnx.GatherND"(%arg0, %arg1) {batch_dims = 0 : si64} : (tensor<2x2x2xf32>, tensor<2x2xi64>) -> tensor<2x2xf32>
+  // CHECK: return [[RES]] : tensor<2x2xf32>
+// -----
+func @test_gather_nd_4(%arg0 : tensor<2x2x2xf32>, %arg1 : tensor<2x1x2xi64>) -> tensor<*xf32> {
+  %0 = "onnx.GatherND"(%arg0, %arg1) {batch_dims = 0 : si64} : (tensor<2x2x2xf32>, tensor<2x1x2xi64>) -> tensor<*xf32>
+  "std.return"(%0) : (tensor<*xf32>) -> ()
+  // CHECK-LABEL: test_gather_nd_4
+  // CHECK: [[RES:%.+]] = "onnx.GatherND"(%arg0, %arg1) {batch_dims = 0 : si64} : (tensor<2x2x2xf32>, tensor<2x1x2xi64>) -> tensor<2x1x2xf32>
+  // CHECK: return [[RES]] : tensor<2x1x2xf32>
+// -----
+func @test_gather_nd_5(%arg0 : tensor<2x2x2xf32>, %arg1 : tensor<2x1xi64>) -> tensor<*xf32> {
+  %0 = "onnx.GatherND"(%arg0, %arg1) {batch_dims = 1 : si64} : (tensor<2x2x2xf32>, tensor<2x1xi64>) -> tensor<*xf32>
+  "std.return"(%0) : (tensor<*xf32>) -> ()
+  // CHECK-LABEL: test_gather_nd_5
+  // CHECK: [[RES:%.+]] = "onnx.GatherND"(%arg0, %arg1) {batch_dims = 1 : si64} : (tensor<2x2x2xf32>, tensor<2x1xi64>) -> tensor<2x2xf32>
+  // CHECK: return [[RES]] : tensor<2x2xf32>
 // -----
 func @test_constant_of_shape_empty_tensor(%arg0 : tensor<0xi64>) -> tensor<*xf32> {
diff --git a/utils/gen_onnx_mlir.py b/utils/gen_onnx_mlir.py
index 5d6ff15d5a..a15e7f254f 100755
--- a/utils/gen_onnx_mlir.py
+++ b/utils/gen_onnx_mlir.py
@@ -294,6 +294,7 @@
+    'GatherND',        