From fdb012c8a1eea1b07845927bb16a37a542a78d1e Mon Sep 17 00:00:00 2001 From: sunhailin Date: Fri, 24 May 2024 18:17:50 +0800 Subject: [PATCH] fix generateGRPCRequest missing tensor shape --- models/transformers/bert.go | 3 +++ 1 file changed, 3 insertions(+) diff --git a/models/transformers/bert.go b/models/transformers/bert.go index ca17ba8..0794874 100644 --- a/models/transformers/bert.go +++ b/models/transformers/bert.go @@ -233,18 +233,21 @@ func (m *BertModelService) generateGRPCRequest( m.grpcSliceToLittleEndianByteSlice( m.MaxSeqLength, feature.TypeIDs, inferInputTensor[j].Datatype)..., ) + inferInputTensor[j].Shape = []int64{int64(len(inferDataArr)), int64(m.MaxSeqLength)} case ModelBertModelInputIdsKey: inputIdsBytes = append( inputIdsBytes, m.grpcSliceToLittleEndianByteSlice( m.MaxSeqLength, feature.TokenIDs, inferInputTensor[j].Datatype)..., ) + inferInputTensor[j].Shape = []int64{int64(len(inferDataArr)), int64(m.MaxSeqLength)} case ModelBertModelInputMaskKey: inputMaskBytes = append( inputMaskBytes, m.grpcSliceToLittleEndianByteSlice( m.MaxSeqLength, feature.Mask, inferInputTensor[j].Datatype)..., ) + inferInputTensor[j].Shape = []int64{int64(len(inferDataArr)), int64(m.MaxSeqLength)} } } batchModelInputObjs[i] = inputObject