Skip to content

Commit

Permalink
Add new api for SetAPIRequestTimeout
Browse files Browse the repository at this point in the history
  • Loading branch information
sunhailin committed Jul 26, 2024
1 parent 4dc88a0 commit 393c271
Show file tree
Hide file tree
Showing 5 changed files with 139 additions and 184 deletions.
4 changes: 4 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,10 @@ func main() {

### Version

* version 2.0.5 - 2024/07/26
* Remove the per-call timeout parameter from the `TritonService` interface; use `SetAPIRequestTimeout` instead.
* Add a new API: `SetAPIRequestTimeout`.

* version 2.0.4 - 2024/07/09
* Update `W2NER` input feature problem.(Missing `MaxSeqLength` config)
* Code style fix. Reducing nil cases
Expand Down
52 changes: 24 additions & 28 deletions models/base.go
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,14 @@ func (m *ModelService) SetSecondaryServerURL(url string) *ModelService {
return m
}

// SetAPIRequestTimeout sets the timeout applied to subsequent Triton API
// requests. It is a no-op when no TritonService is attached and returns the
// receiver so calls can be chained.
func (m *ModelService) SetAPIRequestTimeout(timeout time.Duration) *ModelService {
	if m.TritonService == nil {
		return m
	}
	m.TritonService.SetAPIRequestTimeout(timeout)
	return m
}

// SetJsonEncoder set json encoder
func (m *ModelService) SetJsonEncoder(encoder utils.JSONMarshal) *ModelService {
m.TritonService.SetJSONEncoder(encoder)
Expand All @@ -130,55 +138,43 @@ func (m *ModelService) SetJsonDecoder(decoder utils.JSONUnmarshal) *ModelService
//////////////////////////////////////////// Triton Service API Function ////////////////////////////////////////////

// CheckServerReady check server is ready.
func (m *ModelService) CheckServerReady(requestTimeout time.Duration) (bool, error) {
return m.TritonService.CheckServerReady(requestTimeout)
func (m *ModelService) CheckServerReady() (bool, error) {
return m.TritonService.CheckServerReady()
}

// CheckServerAlive check server is alive.
func (m *ModelService) CheckServerAlive(requestTimeout time.Duration) (bool, error) {
return m.TritonService.CheckServerAlive(requestTimeout)
func (m *ModelService) CheckServerAlive() (bool, error) {
return m.TritonService.CheckServerAlive()
}

// CheckModelReady check model is ready.
func (m *ModelService) CheckModelReady(
modelName, modelVersion string, requestTimeout time.Duration,
) (bool, error) {
return m.TritonService.CheckModelReady(modelName, modelVersion, requestTimeout)
func (m *ModelService) CheckModelReady(modelName, modelVersion string) (bool, error) {
return m.TritonService.CheckModelReady(modelName, modelVersion)
}

// GetServerMeta get server meta.
func (m *ModelService) GetServerMeta(
requestTimeout time.Duration,
) (*nvidia_inferenceserver.ServerMetadataResponse, error) {
return m.TritonService.ServerMetadata(requestTimeout)
func (m *ModelService) GetServerMeta() (*nvidia_inferenceserver.ServerMetadataResponse, error) {
return m.TritonService.ServerMetadata()
}

// GetModelMeta get model meta.
func (m *ModelService) GetModelMeta(
modelName, modelVersion string, requestTimeout time.Duration,
) (*nvidia_inferenceserver.ModelMetadataResponse, error) {
return m.TritonService.ModelMetadataRequest(modelName, modelVersion, requestTimeout)
func (m *ModelService) GetModelMeta(modelName, modelVersion string) (*nvidia_inferenceserver.ModelMetadataResponse, error) {
return m.TritonService.ModelMetadataRequest(modelName, modelVersion)
}

// GetAllModelInfo get all model info.
func (m *ModelService) GetAllModelInfo(
repoName string, isReady bool, requestTimeout time.Duration,
) (*nvidia_inferenceserver.RepositoryIndexResponse, error) {
return m.TritonService.ModelIndex(repoName, isReady, requestTimeout)
func (m *ModelService) GetAllModelInfo(repoName string, isReady bool) (*nvidia_inferenceserver.RepositoryIndexResponse, error) {
return m.TritonService.ModelIndex(repoName, isReady)
}

// GetModelConfig get model config.
func (m *ModelService) GetModelConfig(
modelName, modelVersion string, requestTimeout time.Duration,
) (interface{}, error) {
return m.TritonService.ModelConfiguration(modelName, modelVersion, requestTimeout)
func (m *ModelService) GetModelConfig(modelName, modelVersion string) (interface{}, error) {
return m.TritonService.ModelConfiguration(modelName, modelVersion)
}

// GetModelInferStats get model infer stats.
func (m *ModelService) GetModelInferStats(
modelName, modelVersion string, requestTimeout time.Duration,
) (*nvidia_inferenceserver.ModelStatisticsResponse, error) {
return m.TritonService.ModelInferStats(modelName, modelVersion, requestTimeout)
func (m *ModelService) GetModelInferStats(modelName, modelVersion string) (*nvidia_inferenceserver.ModelStatisticsResponse, error) {
return m.TritonService.ModelInferStats(modelName, modelVersion)
}

//////////////////////////////////////////// Triton Service API Function ////////////////////////////////////////////
6 changes: 2 additions & 4 deletions models/transformers/bert.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package transformers
import (
"encoding/binary"
"strings"
"time"

"github.com/sunhailin-Leo/triton-service-go/v2/models"
"github.com/sunhailin-Leo/triton-service-go/v2/nvidia_inferenceserver"
Expand Down Expand Up @@ -263,7 +262,6 @@ func (m *BertModelService) generateGRPCRequest(
func (m *BertModelService) ModelInfer(
inferData []string,
modelName, modelVersion string,
requestTimeout time.Duration,
params ...interface{},
) ([]interface{}, error) {
// Create request input/output tensors
Expand All @@ -276,7 +274,7 @@ func (m *BertModelService) ModelInfer(
return nil, utils.ErrEmptyGRPCRequestBody
}
return m.TritonService.ModelGRPCInfer(
inferInputs, inferOutputs, grpcRawInputs, modelName, modelVersion, requestTimeout,
inferInputs, inferOutputs, grpcRawInputs, modelName, modelVersion,
m.InferCallback, m, grpcInputData, params,
)
}
Expand All @@ -289,7 +287,7 @@ func (m *BertModelService) ModelInfer(
}
// HTTP Infer
return m.TritonService.ModelHTTPInfer(
httpRequestBody, modelName, modelVersion, requestTimeout,
httpRequestBody, modelName, modelVersion,
m.InferCallback, m, httpInputData, params,
)
}
Expand Down
14 changes: 4 additions & 10 deletions models/transformers/bert_w2ner.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package transformers

import (
"slices"
"time"

"github.com/sunhailin-Leo/triton-service-go/v2/models"
"github.com/sunhailin-Leo/triton-service-go/v2/nvidia_inferenceserver"
Expand Down Expand Up @@ -312,7 +311,6 @@ func (w *W2NerModelService) generateGRPCRequest(
func (w *W2NerModelService) ModelInfer(
inferData [][]string,
modelName, modelVersion string,
requestTimeout time.Duration,
params ...interface{},
) ([]interface{}, error) {
// Create request input/output tensors
Expand All @@ -325,10 +323,8 @@ func (w *W2NerModelService) ModelInfer(
if grpcRawInputs == nil {
return nil, utils.ErrEmptyGRPCRequestBody
}
return w.TritonService.ModelGRPCInfer(
inferInputs, inferOutputs, grpcRawInputs, modelName, modelVersion, requestTimeout,
w.InferCallback, w, grpcInputData, params,
)
return w.TritonService.ModelGRPCInfer(inferInputs, inferOutputs, grpcRawInputs, modelName, modelVersion,
w.InferCallback, w, grpcInputData, params)
}

httpRequestBody, httpInputData, err := w.generateHTTPRequest(inferData, inferInputs, inferOutputs)
Expand All @@ -339,10 +335,8 @@ func (w *W2NerModelService) ModelInfer(
return nil, utils.ErrEmptyHTTPRequestBody
}
// HTTP Infer
return w.TritonService.ModelHTTPInfer(
httpRequestBody, modelName, modelVersion, requestTimeout,
w.InferCallback, w, httpInputData, params,
)
return w.TritonService.ModelHTTPInfer(httpRequestBody, modelName, modelVersion, w.InferCallback,
w, httpInputData, params)
}

//////////////////////////////////////////// Triton Service API Function ////////////////////////////////////////////
Expand Down
Loading

0 comments on commit 393c271

Please sign in to comment.