diff --git a/README.md b/README.md
index 6202206..1f4fcc3 100644
--- a/README.md
+++ b/README.md
@@ -146,6 +146,10 @@ func main() {
 
 ### Version
 
+* version 2.0.5 - 2024/07/26
+  * Remove the per-request `timeout` argument from the `TritonService` interface methods; use `SetAPIRequestTimeout` instead.
+  * Add new API `SetAPIRequestTimeout`
+
 * version 2.0.4 - 2024/07/09
   * Update `W2NER` input feature problem.(Missing `MaxSeqLength` config)
   * Code style fix. Reducing nil cases
diff --git a/models/base.go b/models/base.go
index 9bc96d8..5a8cf65 100644
--- a/models/base.go
+++ b/models/base.go
@@ -113,6 +113,14 @@ func (m *ModelService) SetSecondaryServerURL(url string) *ModelService {
 	return m
 }
 
+// SetAPIRequestTimeout set API request timeout
+func (m *ModelService) SetAPIRequestTimeout(timeout time.Duration) *ModelService {
+	if m.TritonService != nil {
+		m.TritonService.SetAPIRequestTimeout(timeout)
+	}
+	return m
+}
+
 // SetJsonEncoder set json encoder
 func (m *ModelService) SetJsonEncoder(encoder utils.JSONMarshal) *ModelService {
 	m.TritonService.SetJSONEncoder(encoder)
@@ -130,55 +138,43 @@ func (m *ModelService) SetJsonDecoder(decoder utils.JSONUnmarshal) *ModelService
 //////////////////////////////////////////// Triton Service API Function ////////////////////////////////////////////
 
 // CheckServerReady check server is ready.
-func (m *ModelService) CheckServerReady(requestTimeout time.Duration) (bool, error) {
-	return m.TritonService.CheckServerReady(requestTimeout)
+func (m *ModelService) CheckServerReady() (bool, error) {
+	return m.TritonService.CheckServerReady()
 }
 
 // CheckServerAlive check server is alive.
-func (m *ModelService) CheckServerAlive(requestTimeout time.Duration) (bool, error) {
-	return m.TritonService.CheckServerAlive(requestTimeout)
+func (m *ModelService) CheckServerAlive() (bool, error) {
+	return m.TritonService.CheckServerAlive()
 }
 
 // CheckModelReady check model is ready.
-func (m *ModelService) CheckModelReady(
-	modelName, modelVersion string, requestTimeout time.Duration,
-) (bool, error) {
-	return m.TritonService.CheckModelReady(modelName, modelVersion, requestTimeout)
+func (m *ModelService) CheckModelReady(modelName, modelVersion string) (bool, error) {
+	return m.TritonService.CheckModelReady(modelName, modelVersion)
 }
 
 // GetServerMeta get server meta.
-func (m *ModelService) GetServerMeta(
-	requestTimeout time.Duration,
-) (*nvidia_inferenceserver.ServerMetadataResponse, error) {
-	return m.TritonService.ServerMetadata(requestTimeout)
+func (m *ModelService) GetServerMeta() (*nvidia_inferenceserver.ServerMetadataResponse, error) {
+	return m.TritonService.ServerMetadata()
 }
 
 // GetModelMeta get model meta.
-func (m *ModelService) GetModelMeta(
-	modelName, modelVersion string, requestTimeout time.Duration,
-) (*nvidia_inferenceserver.ModelMetadataResponse, error) {
-	return m.TritonService.ModelMetadataRequest(modelName, modelVersion, requestTimeout)
+func (m *ModelService) GetModelMeta(modelName, modelVersion string) (*nvidia_inferenceserver.ModelMetadataResponse, error) {
+	return m.TritonService.ModelMetadataRequest(modelName, modelVersion)
 }
 
 // GetAllModelInfo get all model info.
-func (m *ModelService) GetAllModelInfo(
-	repoName string, isReady bool, requestTimeout time.Duration,
-) (*nvidia_inferenceserver.RepositoryIndexResponse, error) {
-	return m.TritonService.ModelIndex(repoName, isReady, requestTimeout)
+func (m *ModelService) GetAllModelInfo(repoName string, isReady bool) (*nvidia_inferenceserver.RepositoryIndexResponse, error) {
+	return m.TritonService.ModelIndex(repoName, isReady)
 }
 
 // GetModelConfig get model config.
-func (m *ModelService) GetModelConfig(
-	modelName, modelVersion string, requestTimeout time.Duration,
-) (interface{}, error) {
-	return m.TritonService.ModelConfiguration(modelName, modelVersion, requestTimeout)
+func (m *ModelService) GetModelConfig(modelName, modelVersion string) (interface{}, error) {
+	return m.TritonService.ModelConfiguration(modelName, modelVersion)
 }
 
 // GetModelInferStats get model infer stats.
-func (m *ModelService) GetModelInferStats(
-	modelName, modelVersion string, requestTimeout time.Duration,
-) (*nvidia_inferenceserver.ModelStatisticsResponse, error) {
-	return m.TritonService.ModelInferStats(modelName, modelVersion, requestTimeout)
+func (m *ModelService) GetModelInferStats(modelName, modelVersion string) (*nvidia_inferenceserver.ModelStatisticsResponse, error) {
+	return m.TritonService.ModelInferStats(modelName, modelVersion)
 }
 
 //////////////////////////////////////////// Triton Service API Function ////////////////////////////////////////////
diff --git a/models/transformers/bert.go b/models/transformers/bert.go
index 0794874..ac2508f 100644
--- a/models/transformers/bert.go
+++ b/models/transformers/bert.go
@@ -3,7 +3,6 @@ package transformers
 import (
 	"encoding/binary"
 	"strings"
-	"time"
 
 	"github.com/sunhailin-Leo/triton-service-go/v2/models"
 	"github.com/sunhailin-Leo/triton-service-go/v2/nvidia_inferenceserver"
@@ -263,7 +262,6 @@ func (m *BertModelService) generateGRPCRequest(
 func (m *BertModelService) ModelInfer(
 	inferData []string,
 	modelName, modelVersion string,
-	requestTimeout time.Duration,
 	params ...interface{},
 ) ([]interface{}, error) {
 	// Create request input/output tensors
@@ -276,7 +274,7 @@ func (m *BertModelService) ModelInfer(
 			return nil, utils.ErrEmptyGRPCRequestBody
 		}
 		return m.TritonService.ModelGRPCInfer(
-			inferInputs, inferOutputs, grpcRawInputs, modelName, modelVersion, requestTimeout,
+			inferInputs, inferOutputs, grpcRawInputs, modelName, modelVersion,
 			m.InferCallback, m, grpcInputData, params,
 		)
 	}
@@ -289,7 +287,7 @@ func (m *BertModelService) ModelInfer(
 	}
 	// HTTP Infer
 	return m.TritonService.ModelHTTPInfer(
-		httpRequestBody, modelName, modelVersion, requestTimeout,
+		httpRequestBody, modelName, modelVersion,
 		m.InferCallback, m, httpInputData, params,
 	)
 }
diff --git a/models/transformers/bert_w2ner.go b/models/transformers/bert_w2ner.go
index 4fd6bb7..5db0c32 100644
--- a/models/transformers/bert_w2ner.go
+++ b/models/transformers/bert_w2ner.go
@@ -2,7 +2,6 @@ package transformers
 
 import (
 	"slices"
-	"time"
 
 	"github.com/sunhailin-Leo/triton-service-go/v2/models"
 	"github.com/sunhailin-Leo/triton-service-go/v2/nvidia_inferenceserver"
@@ -312,7 +311,6 @@ func (w *W2NerModelService) generateGRPCRequest(
 func (w *W2NerModelService) ModelInfer(
 	inferData [][]string,
 	modelName, modelVersion string,
-	requestTimeout time.Duration,
 	params ...interface{},
 ) ([]interface{}, error) {
 	// Create request input/output tensors
@@ -325,10 +323,8 @@ func (w *W2NerModelService) ModelInfer(
 		if grpcRawInputs == nil {
 			return nil, utils.ErrEmptyGRPCRequestBody
 		}
-		return w.TritonService.ModelGRPCInfer(
-			inferInputs, inferOutputs, grpcRawInputs, modelName, modelVersion, requestTimeout,
-			w.InferCallback, w, grpcInputData, params,
-		)
+		return w.TritonService.ModelGRPCInfer(inferInputs, inferOutputs, grpcRawInputs, modelName, modelVersion,
+			w.InferCallback, w, grpcInputData, params)
 	}
 
 	httpRequestBody, httpInputData, err := w.generateHTTPRequest(inferData, inferInputs, inferOutputs)
@@ -339,10 +335,8 @@ func (w *W2NerModelService) ModelInfer(
 		return nil, utils.ErrEmptyHTTPRequestBody
 	}
 	// HTTP Infer
-	return w.TritonService.ModelHTTPInfer(
-		httpRequestBody, modelName, modelVersion, requestTimeout,
-		w.InferCallback, w, httpInputData, params,
-	)
+	return w.TritonService.ModelHTTPInfer(httpRequestBody, modelName, modelVersion, w.InferCallback,
+		w, httpInputData, params)
 }
 
 //////////////////////////////////////////// Triton Service API Function ////////////////////////////////////////////
diff --git a/nvidia_inferenceserver/triton_service_interface.go b/nvidia_inferenceserver/triton_service_interface.go
index 16ba68d..ca608a3 100644
--- a/nvidia_inferenceserver/triton_service_interface.go
+++ b/nvidia_inferenceserver/triton_service_interface.go
@@ -33,23 +33,27 @@ const (
 // DecoderFunc Infer Callback Function.
 type DecoderFunc func(response interface{}, params ...interface{}) ([]interface{}, error)
 
-// TritonGRPCService Service interface.
-type TritonGRPCService interface {
+// TritonService Service interface.
+type TritonService interface {
+	// JsonMarshal Json Encoder
+	JsonMarshal(v interface{}) ([]byte, error)
+	// JsonUnmarshal Json Decoder
+	JsonUnmarshal(data []byte, v interface{}) error
+
 	// CheckServerAlive Check triton inference server is alive.
-	CheckServerAlive(timeout time.Duration) (bool, error)
+	CheckServerAlive() (bool, error)
 	// CheckServerReady Check triton inference server is ready.
-	CheckServerReady(timeout time.Duration) (bool, error)
+	CheckServerReady() (bool, error)
 	// CheckModelReady Check triton inference server`s model is ready.
-	CheckModelReady(modelName, modelVersion string, timeout time.Duration) (bool, error)
+	CheckModelReady(modelName, modelVersion string) (bool, error)
 	// ServerMetadata Get triton inference server metadata.
-	ServerMetadata(timeout time.Duration) (*ServerMetadataResponse, error)
+	ServerMetadata() (*ServerMetadataResponse, error)
 	// ModelGRPCInfer Call triton inference server infer with GRPC
 	ModelGRPCInfer(
 		inferInputs []*ModelInferRequest_InferInputTensor,
 		inferOutputs []*ModelInferRequest_InferRequestedOutputTensor,
 		rawInputs [][]byte,
 		modelName, modelVersion string,
-		timeout time.Duration,
 		decoderFunc DecoderFunc,
 		params ...interface{},
 	) ([]interface{}, error)
@@ -57,54 +61,44 @@ type TritonGRPCService interface {
 	ModelHTTPInfer(
 		requestBody []byte,
 		modelName, modelVersion string,
-		timeout time.Duration,
 		decoderFunc DecoderFunc,
 		params ...interface{},
 	) ([]interface{}, error)
 	// ModelMetadataRequest Get triton inference server`s model metadata.
-	ModelMetadataRequest(modelName, modelVersion string, timeout time.Duration) (*ModelMetadataResponse, error)
+	ModelMetadataRequest(modelName, modelVersion string) (*ModelMetadataResponse, error)
 	// ModelIndex Get triton inference server model index.
-	ModelIndex(isReady bool, timeout time.Duration) (*RepositoryIndexResponse, error)
+	ModelIndex(isReady bool) (*RepositoryIndexResponse, error)
 	// ModelConfiguration Get triton inference server model configuration.
-	ModelConfiguration(modelName, modelVersion string, timeout time.Duration) (interface{}, error)
+	ModelConfiguration(modelName, modelVersion string) (interface{}, error)
 	// ModelInferStats Get triton inference server model infer stats.
-	ModelInferStats(modelName, modelVersion string, timeout time.Duration) (*ModelStatisticsResponse, error)
+	ModelInferStats(modelName, modelVersion string) (*ModelStatisticsResponse, error)
 	// ModelLoadWithHTTP Load model with http.
-	ModelLoadWithHTTP(
-		modelName string, modelConfigBody []byte, timeout time.Duration) (*RepositoryModelLoadResponse, error)
+	ModelLoadWithHTTP(modelName string, modelConfigBody []byte) (*RepositoryModelLoadResponse, error)
 	// ModelLoadWithGRPC Load model with http.
-	ModelLoadWithGRPC(
-		repoName, modelName string, modelConfigBody map[string]*ModelRepositoryParameter, timeout time.Duration,
-	) (*RepositoryModelLoadResponse, error)
+	ModelLoadWithGRPC(repoName, modelName string, modelConfigBody map[string]*ModelRepositoryParameter) (*RepositoryModelLoadResponse, error)
 	// ModelUnloadWithHTTP Unload model with http.
-	ModelUnloadWithHTTP(
-		modelName string, modelConfigBody []byte, timeout time.Duration,
-	) (*RepositoryModelUnloadResponse, error)
+	ModelUnloadWithHTTP(modelName string, modelConfigBody []byte) (*RepositoryModelUnloadResponse, error)
 	// ModelUnloadWithGRPC Unload model with grpc.
-	ModelUnloadWithGRPC(
-		repoName, modelName string, modelConfigBody map[string]*ModelRepositoryParameter, timeout time.Duration,
-	) (*RepositoryModelUnloadResponse, error)
+	ModelUnloadWithGRPC(repoName, modelName string, modelConfigBody map[string]*ModelRepositoryParameter) (*RepositoryModelUnloadResponse, error)
 	// ShareMemoryStatus Show share memory / share cuda memory status.
-	ShareMemoryStatus(isCUDA bool, regionName string, timeout time.Duration) (interface{}, error)
+	ShareMemoryStatus(isCUDA bool, regionName string) (interface{}, error)
 	// ShareCUDAMemoryRegister Register share cuda memory.
-	ShareCUDAMemoryRegister(
-		regionName string, cudaRawHandle []byte, cudaDeviceID int64, byteSize uint64, timeout time.Duration,
-	) (interface{}, error)
+	ShareCUDAMemoryRegister(regionName string, cudaRawHandle []byte, cudaDeviceID int64, byteSize uint64) (interface{}, error)
 	// ShareCUDAMemoryUnRegister Unregister share cuda memory
-	ShareCUDAMemoryUnRegister(regionName string, timeout time.Duration) (interface{}, error)
+	ShareCUDAMemoryUnRegister(regionName string) (interface{}, error)
 	// ShareSystemMemoryRegister Register system share memory.
-	ShareSystemMemoryRegister(
-		regionName, cpuMemRegionKey string, byteSize, cpuMemOffset uint64, timeout time.Duration,
-	) (interface{}, error)
+	ShareSystemMemoryRegister(regionName, cpuMemRegionKey string, byteSize, cpuMemOffset uint64) (interface{}, error)
 	// ShareSystemMemoryUnRegister Unregister system share memory.
-	ShareSystemMemoryUnRegister(regionName string, timeout time.Duration) (interface{}, error)
+	ShareSystemMemoryUnRegister(regionName string) (interface{}, error)
 	// GetModelTracingSetting get the current trace setting.
-	GetModelTracingSetting(modelName string, timeout time.Duration) (*TraceSettingResponse, error)
+	GetModelTracingSetting(modelName string) (*TraceSettingResponse, error)
 	// SetModelTracingSetting set the current trace setting.
-	SetModelTracingSetting(
-		modelName string, settingMap map[string]*TraceSettingRequest_SettingValue, timeout time.Duration,
-	) (*TraceSettingResponse, error)
+	SetModelTracingSetting(modelName string, settingMap map[string]*TraceSettingRequest_SettingValue) (*TraceSettingResponse, error)
+	// SetSecondaryServerURL Set secondary server url
+	SetSecondaryServerURL(url string)
+	// SetAPIRequestTimeout Set API request timeout.
+	SetAPIRequestTimeout(timeout time.Duration)
 
 	// ShutdownTritonConnection close client connection.
 	ShutdownTritonConnection() (disconnectionErr error)
 }
@@ -124,10 +118,10 @@ type TritonClientService struct {
 	grpcConn   *grpc.ClientConn
 	grpcClient GRPCInferenceServiceClient
 	httpClient *fasthttp.Client
+	apiTimeout time.Duration
 
 	// Default: json.Marshal
 	JSONEncoder utils.JSONMarshal
-	// Default: json.Unmarshal
 	JSONDecoder utils.JSONUnmarshal
 }
 
@@ -174,9 +168,7 @@ func (t *TritonClientService) getServerURL() string {
 }
 
 // makeHTTPPostRequestWithDoTimeout make http post request with timeout.
-func (t *TritonClientService) makeHTTPPostRequestWithDoTimeout(
-	uri string, reqBody []byte, timeout time.Duration,
-) (*fasthttp.Response, error) {
+func (t *TritonClientService) makeHTTPPostRequestWithDoTimeout(uri string, reqBody []byte) (*fasthttp.Response, error) {
 	requestObj := t.acquireHTTPRequest(fasthttp.MethodPost)
 	responseObj := fasthttp.AcquireResponse()
 	defer fasthttp.ReleaseRequest(requestObj)
@@ -185,22 +177,20 @@ func (t *TritonClientService) makeHTTPPostRequestWithDoTimeout(
 	if reqBody != nil {
 		requestObj.SetBody(reqBody)
 	}
-	if httpErr := t.httpClient.DoTimeout(requestObj, responseObj, timeout); httpErr != nil {
+	if httpErr := t.httpClient.DoTimeout(requestObj, responseObj, t.apiTimeout); httpErr != nil {
 		return responseObj, httpErr
 	}
 	return responseObj, nil
 }
 
 // makeHTTPGetRequestWithDoTimeout make http get request with timeout.
-func (t *TritonClientService) makeHTTPGetRequestWithDoTimeout(
-	uri string, timeout time.Duration,
-) (*fasthttp.Response, error) {
+func (t *TritonClientService) makeHTTPGetRequestWithDoTimeout(uri string) (*fasthttp.Response, error) {
 	requestObj := t.acquireHTTPRequest(fasthttp.MethodGet)
 	responseObj := fasthttp.AcquireResponse()
 	defer fasthttp.ReleaseRequest(requestObj)
 
 	requestObj.SetRequestURI(uri)
-	if httpErr := t.httpClient.DoTimeout(requestObj, responseObj, timeout); httpErr != nil {
+	if httpErr := t.httpClient.DoTimeout(requestObj, responseObj, t.apiTimeout); httpErr != nil {
 		return responseObj, httpErr
 	}
 	return responseObj, nil
@@ -212,9 +202,8 @@ func (t *TritonClientService) modelGRPCInfer(
 	inferOutputs []*ModelInferRequest_InferRequestedOutputTensor,
 	rawInputs [][]byte,
 	modelName, modelVersion string,
-	timeout time.Duration,
 ) (*ModelInferResponse, error) {
-	ctx, cancel := context.WithTimeout(context.Background(), timeout)
+	ctx, cancel := context.WithTimeout(context.Background(), t.apiTimeout)
 	defer cancel()
 	// Create infer request for specific model/version.
 	modelInferRequest := ModelInferRequest{
@@ -227,14 +216,17 @@ func (t *TritonClientService) modelGRPCInfer(
 	// Get infer response.
 	modelInferResponse, inferErr := t.grpcClient.ModelInfer(ctx, &modelInferRequest)
 	if inferErr != nil {
-		return nil, errors.New("inferErr: " + inferErr.Error())
+		return nil, errors.New("[GRPC]inferErr: " + inferErr.Error())
 	}
 	return modelInferResponse, nil
 }
 
 // httpErrorHandler HTTP Error Handler.
 func (t *TritonClientService) httpErrorHandler(statusCode int, httpErr error) error {
-	return errors.New("[HTTP]code: " + strconv.Itoa(statusCode) + "; error: " + httpErr.Error())
+	if httpErr != nil {
+		return errors.New("[HTTP]code: " + strconv.Itoa(statusCode) + "; error: " + httpErr.Error())
+	}
+	return nil
 }
 
 // grpcErrorHandler GRPC Error Handler.
@@ -281,15 +273,13 @@ func (t *TritonClientService) SetJsonDecoder(decoder utils.JSONUnmarshal) *Trito
 func (t *TritonClientService) ModelHTTPInfer(
 	requestBody []byte,
 	modelName, modelVersion string,
-	timeout time.Duration,
 	decoderFunc DecoderFunc,
 	params ...interface{},
 ) ([]interface{}, error) {
 	// get infer response.
 	modelInferResponse, inferErr := t.makeHTTPPostRequestWithDoTimeout(
 		t.getServerURL()+TritonAPIForModelPrefix+modelName+TritonAPIForModelVersionPrefix+modelVersion+"/infer",
-		requestBody,
-		timeout)
+		requestBody)
 	defer fasthttp.ReleaseResponse(modelInferResponse)
 
 	if modelInferResponse == nil {
@@ -316,13 +306,12 @@ func (t *TritonClientService) ModelGRPCInfer(
 	inferOutputs []*ModelInferRequest_InferRequestedOutputTensor,
 	rawInputs [][]byte,
 	modelName, modelVersion string,
-	timeout time.Duration,
 	decoderFunc DecoderFunc,
 	params ...interface{},
 ) ([]interface{}, error) {
 	// Get infer response.
 	modelInferResponse, inferErr := t.modelGRPCInfer(
-		inferInputs, inferOutputs, rawInputs, modelName, modelVersion, timeout)
+		inferInputs, inferOutputs, rawInputs, modelName, modelVersion)
 	if inferErr != nil {
 		return nil, t.grpcErrorHandler(inferErr)
 	}
@@ -335,9 +324,9 @@ func (t *TritonClientService) ModelGRPCInfer(
 }
 
 // CheckServerAlive check server is alive.
-func (t *TritonClientService) CheckServerAlive(timeout time.Duration) (bool, error) {
+func (t *TritonClientService) CheckServerAlive() (bool, error) {
 	if t.grpcClient != nil {
-		ctx, cancel := context.WithTimeout(context.Background(), timeout)
+		ctx, cancel := context.WithTimeout(context.Background(), t.apiTimeout)
 		defer cancel()
 
 		// server alive.
@@ -347,7 +336,7 @@ func (t *TritonClientService) CheckServerAlive(timeout time.Duration) (bool, err
 		}
 		return serverLiveResponse.Live, nil
 	}
-	apiResp, httpErr := t.makeHTTPPostRequestWithDoTimeout(t.getServerURL()+TritonAPIForServerIsLive, nil, timeout)
+	apiResp, httpErr := t.makeHTTPPostRequestWithDoTimeout(t.getServerURL()+TritonAPIForServerIsLive, nil)
 	defer fasthttp.ReleaseResponse(apiResp)
 	if apiResp == nil {
 		return false, t.httpErrorHandler(http.StatusInternalServerError, utils.ErrApiRespNil)
@@ -359,9 +348,9 @@ func (t *TritonClientService) CheckServerAlive(timeout time.Duration) (bool, err
 }
 
 // CheckServerReady check server is ready.
-func (t *TritonClientService) CheckServerReady(timeout time.Duration) (bool, error) {
+func (t *TritonClientService) CheckServerReady() (bool, error) {
 	if t.grpcClient != nil {
-		ctx, cancel := context.WithTimeout(context.Background(), timeout)
+		ctx, cancel := context.WithTimeout(context.Background(), t.apiTimeout)
 		defer cancel()
 
 		// server ready.
@@ -371,7 +360,7 @@ func (t *TritonClientService) CheckServerReady(timeout time.Duration) (bool, err
 		}
 		return serverReadyResponse.Ready, nil
 	}
-	apiResp, httpErr := t.makeHTTPPostRequestWithDoTimeout(t.getServerURL()+TritonAPIForServerIsReady, nil, timeout)
+	apiResp, httpErr := t.makeHTTPPostRequestWithDoTimeout(t.getServerURL()+TritonAPIForServerIsReady, nil)
 	defer fasthttp.ReleaseResponse(apiResp)
 	if apiResp == nil {
 		return false, t.httpErrorHandler(http.StatusInternalServerError, utils.ErrApiRespNil)
@@ -383,9 +372,9 @@ func (t *TritonClientService) CheckServerReady(timeout time.Duration) (bool, err
 }
 
 // CheckModelReady check model is ready.
-func (t *TritonClientService) CheckModelReady(modelName, modelVersion string, timeout time.Duration) (bool, error) {
+func (t *TritonClientService) CheckModelReady(modelName, modelVersion string) (bool, error) {
 	if t.grpcClient != nil {
-		ctx, cancel := context.WithTimeout(context.Background(), timeout)
+		ctx, cancel := context.WithTimeout(context.Background(), t.apiTimeout)
 		defer cancel()
 
 		// model ready
@@ -398,7 +387,7 @@ func (t *TritonClientService) CheckModelReady(modelName, modelVersion string, ti
 	}
 	apiResp, httpErr := t.makeHTTPPostRequestWithDoTimeout(
 		t.getServerURL()+TritonAPIForModelPrefix+modelName+TritonAPIForModelVersionPrefix+modelVersion+"/ready",
-		nil, timeout)
+		nil)
 	defer fasthttp.ReleaseResponse(apiResp)
 	if apiResp == nil {
 		return false, t.httpErrorHandler(http.StatusInternalServerError, utils.ErrApiRespNil)
@@ -410,16 +399,16 @@ func (t *TritonClientService) CheckModelReady(modelName, modelVersion string, ti
 }
 
 // ServerMetadata Get server metadata.
-func (t *TritonClientService) ServerMetadata(timeout time.Duration) (*ServerMetadataResponse, error) {
+func (t *TritonClientService) ServerMetadata() (*ServerMetadataResponse, error) {
 	if t.grpcClient != nil {
-		ctx, cancel := context.WithTimeout(context.Background(), timeout)
+		ctx, cancel := context.WithTimeout(context.Background(), t.apiTimeout)
 		defer cancel()
 
 		// server metadata
 		serverMetadataResponse, serverMetaErr := t.grpcClient.ServerMetadata(ctx, &ServerMetadataRequest{})
 		return serverMetadataResponse, t.grpcErrorHandler(serverMetaErr)
 	}
-	apiResp, httpErr := t.makeHTTPPostRequestWithDoTimeout(t.getServerURL()+TritonAPIPrefix, nil, timeout)
+	apiResp, httpErr := t.makeHTTPPostRequestWithDoTimeout(t.getServerURL()+TritonAPIPrefix, nil)
 	defer fasthttp.ReleaseResponse(apiResp)
 	if apiResp == nil {
 		return nil, t.httpErrorHandler(http.StatusInternalServerError, utils.ErrApiRespNil)
@@ -435,11 +424,9 @@ func (t *TritonClientService) ServerMetadata(timeout time.Duration) (*ServerMeta
 }
 
 // ModelMetadataRequest Get model metadata.
-func (t *TritonClientService) ModelMetadataRequest(
-	modelName, modelVersion string, timeout time.Duration,
-) (*ModelMetadataResponse, error) {
+func (t *TritonClientService) ModelMetadataRequest(modelName, modelVersion string) (*ModelMetadataResponse, error) {
 	if t.grpcClient != nil {
-		ctx, cancel := context.WithTimeout(context.Background(), timeout)
+		ctx, cancel := context.WithTimeout(context.Background(), t.apiTimeout)
 		defer cancel()
 
 		// model metadata
@@ -448,8 +435,7 @@ func (t *TritonClientService) ModelMetadataRequest(
 		return modelMetadataResponse, t.grpcErrorHandler(modelMetaErr)
 	}
 	apiResp, httpErr := t.makeHTTPPostRequestWithDoTimeout(
-		t.getServerURL()+TritonAPIForModelPrefix+modelName+TritonAPIForModelVersionPrefix+modelVersion,
-		nil, timeout)
+		t.getServerURL()+TritonAPIForModelPrefix+modelName+TritonAPIForModelVersionPrefix+modelVersion, nil)
 	defer fasthttp.ReleaseResponse(apiResp)
 	if apiResp == nil {
 		return nil, t.httpErrorHandler(http.StatusInternalServerError, utils.ErrApiRespNil)
@@ -465,11 +451,9 @@ func (t *TritonClientService) ModelMetadataRequest(
 }
 
 // ModelIndex Get model repo index.
-func (t *TritonClientService) ModelIndex(
-	repoName string, isReady bool, timeout time.Duration,
-) (*RepositoryIndexResponse, error) {
+func (t *TritonClientService) ModelIndex(repoName string, isReady bool) (*RepositoryIndexResponse, error) {
 	if t.grpcClient != nil {
-		ctx, cancel := context.WithTimeout(context.Background(), timeout)
+		ctx, cancel := context.WithTimeout(context.Background(), t.apiTimeout)
 		defer cancel()
 
 		// The name of the repository. If empty the index is returned for all repositories.
@@ -481,7 +465,7 @@ func (t *TritonClientService) ModelIndex(
 	if jsonEncodeErr != nil {
 		return nil, jsonEncodeErr
 	}
-	apiResp, httpErr := t.makeHTTPPostRequestWithDoTimeout(t.getServerURL()+TritonAPIForRepoIndex, reqBody, timeout)
+	apiResp, httpErr := t.makeHTTPPostRequestWithDoTimeout(t.getServerURL()+TritonAPIForRepoIndex, reqBody)
 	defer fasthttp.ReleaseResponse(apiResp)
 	if apiResp == nil {
 		return nil, t.httpErrorHandler(http.StatusInternalServerError, utils.ErrApiRespNil)
@@ -497,11 +481,9 @@ func (t *TritonClientService) ModelIndex(
 }
 
 // ModelConfiguration Get model configuration.
-func (t *TritonClientService) ModelConfiguration(
-	modelName, modelVersion string, timeout time.Duration,
-) (*ModelConfigResponse, error) {
+func (t *TritonClientService) ModelConfiguration(modelName, modelVersion string) (*ModelConfigResponse, error) {
 	if t.grpcClient != nil {
-		ctx, cancel := context.WithTimeout(context.Background(), timeout)
+		ctx, cancel := context.WithTimeout(context.Background(), t.apiTimeout)
 		defer cancel()
 
 		modelConfigResponse, getModelConfigErr := t.grpcClient.ModelConfig(
@@ -509,8 +491,8 @@ func (t *TritonClientService) ModelConfiguration(
 		return modelConfigResponse, t.grpcErrorHandler(getModelConfigErr)
 	}
 	apiResp, httpErr := t.makeHTTPGetRequestWithDoTimeout(
-		t.getServerURL()+TritonAPIForModelPrefix+modelName+
-			TritonAPIForModelVersionPrefix+modelVersion+"/config", timeout)
+		t.getServerURL() + TritonAPIForModelPrefix + modelName +
+			TritonAPIForModelVersionPrefix + modelVersion + "/config")
 	defer fasthttp.ReleaseResponse(apiResp)
 	if apiResp == nil {
 		return nil, t.httpErrorHandler(http.StatusInternalServerError, utils.ErrApiRespNil)
@@ -526,11 +508,9 @@ func (t *TritonClientService) ModelConfiguration(
 }
 
 // ModelInferStats Get Model infer stats.
-func (t *TritonClientService) ModelInferStats(
-	modelName, modelVersion string, timeout time.Duration,
-) (*ModelStatisticsResponse, error) {
+func (t *TritonClientService) ModelInferStats(modelName, modelVersion string) (*ModelStatisticsResponse, error) {
 	if t.grpcClient != nil {
-		ctx, cancel := context.WithTimeout(context.Background(), timeout)
+		ctx, cancel := context.WithTimeout(context.Background(), t.apiTimeout)
 		defer cancel()
 
 		modelStatisticsResponse, getInferStatsErr := t.grpcClient.ModelStatistics(
@@ -538,8 +518,7 @@ func (t *TritonClientService) ModelInferStats(
 		return modelStatisticsResponse, t.grpcErrorHandler(getInferStatsErr)
 	}
 	apiResp, httpErr := t.makeHTTPGetRequestWithDoTimeout(
-		t.getServerURL()+TritonAPIForModelPrefix+modelName+TritonAPIForModelVersionPrefix+modelVersion+"/stats",
-		timeout)
+		t.getServerURL() + TritonAPIForModelPrefix + modelName + TritonAPIForModelVersionPrefix + modelVersion + "/stats")
 	defer fasthttp.ReleaseResponse(apiResp)
 	if apiResp == nil {
 		return nil, t.httpErrorHandler(http.StatusInternalServerError, utils.ErrApiRespNil)
@@ -558,11 +537,9 @@ func (t *TritonClientService) ModelInferStats(
 }
 
 // ModelLoadWithHTTP Load Model with http
 // modelConfigBody ==>
 // https://github.com/triton-inference-server/server/blob/main/docs/protocol/extension_model_repository.md#examples
-func (t *TritonClientService) ModelLoadWithHTTP(
-	modelName string, modelConfigBody []byte, timeout time.Duration,
-) (*RepositoryModelLoadResponse, error) {
+func (t *TritonClientService) ModelLoadWithHTTP(modelName string, modelConfigBody []byte) (*RepositoryModelLoadResponse, error) {
 	apiResp, httpErr := t.makeHTTPPostRequestWithDoTimeout(
-		t.getServerURL()+TritonAPIForRepoModelPrefix+modelName+"/load", modelConfigBody, timeout)
+		t.getServerURL()+TritonAPIForRepoModelPrefix+modelName+"/load", modelConfigBody)
 	defer fasthttp.ReleaseResponse(apiResp)
 	if apiResp == nil {
 		return nil, t.httpErrorHandler(http.StatusInternalServerError, utils.ErrApiRespNil)
@@ -578,10 +555,8 @@ func (t *TritonClientService) ModelLoadWithHTTP(
 }
 
 // ModelLoadWithGRPC Load Model with grpc.
-func (t *TritonClientService) ModelLoadWithGRPC(
-	repoName, modelName string, modelConfigBody map[string]*ModelRepositoryParameter, timeout time.Duration,
-) (*RepositoryModelLoadResponse, error) {
-	ctx, cancel := context.WithTimeout(context.Background(), timeout)
+func (t *TritonClientService) ModelLoadWithGRPC(repoName, modelName string, modelConfigBody map[string]*ModelRepositoryParameter) (*RepositoryModelLoadResponse, error) {
+	ctx, cancel := context.WithTimeout(context.Background(), t.apiTimeout)
 	defer cancel()
 	// The name of the repository to load from. If empty the model is loaded from any repository.
 	loadResponse, loadErr := t.grpcClient.RepositoryModelLoad(ctx, &RepositoryModelLoadRequest{
@@ -594,11 +569,9 @@ func (t *TritonClientService) ModelLoadWithGRPC(
 
 // ModelUnloadWithHTTP Unload model with http
 // modelConfigBody if not is nil.
-func (t *TritonClientService) ModelUnloadWithHTTP(
-	modelName string, modelConfigBody []byte, timeout time.Duration,
-) (*RepositoryModelUnloadResponse, error) {
+func (t *TritonClientService) ModelUnloadWithHTTP(modelName string, modelConfigBody []byte) (*RepositoryModelUnloadResponse, error) {
 	apiResp, httpErr := t.makeHTTPPostRequestWithDoTimeout(
-		t.getServerURL()+TritonAPIForRepoModelPrefix+modelName+"/unload", modelConfigBody, timeout)
+		t.getServerURL()+TritonAPIForRepoModelPrefix+modelName+"/unload", modelConfigBody)
 	defer fasthttp.ReleaseResponse(apiResp)
 	if apiResp == nil {
 		return nil, t.httpErrorHandler(http.StatusInternalServerError, utils.ErrApiRespNil)
@@ -616,10 +589,8 @@ func (t *TritonClientService) ModelUnloadWithHTTP(
 }
 
 // ModelUnloadWithGRPC Unload model with grpc
 // modelConfigBody if not is nil.
-func (t *TritonClientService) ModelUnloadWithGRPC(
-	repoName, modelName string, modelConfigBody map[string]*ModelRepositoryParameter, timeout time.Duration,
-) (*RepositoryModelUnloadResponse, error) {
-	ctx, cancel := context.WithTimeout(context.Background(), timeout)
+func (t *TritonClientService) ModelUnloadWithGRPC(repoName, modelName string, modelConfigBody map[string]*ModelRepositoryParameter) (*RepositoryModelUnloadResponse, error) {
+	ctx, cancel := context.WithTimeout(context.Background(), t.apiTimeout)
 	defer cancel()
 
 	unloadResponse, unloadErr := t.grpcClient.RepositoryModelUnload(ctx, &RepositoryModelUnloadRequest{
@@ -632,11 +603,9 @@ func (t *TritonClientService) ModelUnloadWithGRPC(
 
 // ShareMemoryStatus Get share memory / cuda memory status.
 // Response: CudaSharedMemoryStatusResponse / SystemSharedMemoryStatusResponse.
-func (t *TritonClientService) ShareMemoryStatus(
-	isCUDA bool, regionName string, timeout time.Duration,
-) (interface{}, error) {
+func (t *TritonClientService) ShareMemoryStatus(isCUDA bool, regionName string) (interface{}, error) {
 	if t.grpcClient != nil {
-		ctx, cancel := context.WithTimeout(context.Background(), timeout)
+		ctx, cancel := context.WithTimeout(context.Background(), t.apiTimeout)
 		defer cancel()
 
 		if isCUDA {
@@ -661,7 +630,7 @@ func (t *TritonClientService) ShareMemoryStatus(
 	} else {
 		uri = t.getServerURL() + TritonAPIForSystemMemoryRegionPrefix + regionName + "/status"
 	}
-	apiResp, httpErr := t.makeHTTPGetRequestWithDoTimeout(uri, timeout)
+	apiResp, httpErr := t.makeHTTPGetRequestWithDoTimeout(uri)
 	defer fasthttp.ReleaseResponse(apiResp)
 	if apiResp == nil {
 		return false, t.httpErrorHandler(http.StatusInternalServerError, utils.ErrApiRespNil)
@@ -685,11 +654,9 @@ func (t *TritonClientService) ShareMemoryStatus(
 }
 
 // ShareCUDAMemoryRegister cuda share memory register.
-func (t *TritonClientService) ShareCUDAMemoryRegister(
-	regionName string, cudaRawHandle []byte, cudaDeviceID int64, byteSize uint64, timeout time.Duration,
-) (*CudaSharedMemoryRegisterResponse, error) {
+func (t *TritonClientService) ShareCUDAMemoryRegister(regionName string, cudaRawHandle []byte, cudaDeviceID int64, byteSize uint64) (*CudaSharedMemoryRegisterResponse, error) {
 	if t.grpcClient != nil {
-		ctx, cancel := context.WithTimeout(context.Background(), timeout)
+		ctx, cancel := context.WithTimeout(context.Background(), t.apiTimeout)
 		defer cancel()
 
 		// CUDA Memory
@@ -709,7 +676,7 @@ func (t *TritonClientService) ShareCUDAMemoryRegister(
 		return nil, jsonEncodeErr
 	}
 	apiResp, httpErr := t.makeHTTPPostRequestWithDoTimeout(
-		t.getServerURL()+TritonAPIForCudaMemoryRegionPrefix+regionName+"/register", reqBody, timeout)
+		t.getServerURL()+TritonAPIForCudaMemoryRegionPrefix+regionName+"/register", reqBody)
 	defer fasthttp.ReleaseResponse(apiResp)
 	if apiResp == nil {
 		return nil, t.httpErrorHandler(http.StatusInternalServerError, utils.ErrApiRespNil)
@@ -725,11 +692,9 @@ func (t *TritonClientService) ShareCUDAMemoryRegister(
 }
 
 // ShareCUDAMemoryUnRegister cuda share memory unregister.
-func (t *TritonClientService) ShareCUDAMemoryUnRegister(
-	regionName string, timeout time.Duration,
-) (*CudaSharedMemoryUnregisterResponse, error) {
+func (t *TritonClientService) ShareCUDAMemoryUnRegister(regionName string) (*CudaSharedMemoryUnregisterResponse, error) {
 	if t.grpcClient != nil {
-		ctx, cancel := context.WithTimeout(context.Background(), timeout)
+		ctx, cancel := context.WithTimeout(context.Background(), t.apiTimeout)
 		defer cancel()
 
 		// CUDA Memory
@@ -738,7 +703,7 @@ func (t *TritonClientService) ShareCUDAMemoryUnRegister(
 		return cudaSharedMemoryUnRegisterResponse, t.grpcErrorHandler(unRegisterErr)
 	}
 	apiResp, httpErr := t.makeHTTPPostRequestWithDoTimeout(
-		t.getServerURL()+TritonAPIForCudaMemoryRegionPrefix+regionName+"/unregister", nil, timeout)
+		t.getServerURL()+TritonAPIForCudaMemoryRegionPrefix+regionName+"/unregister", nil)
 	defer fasthttp.ReleaseResponse(apiResp)
 	if apiResp == nil {
 		return nil, t.httpErrorHandler(http.StatusInternalServerError, utils.ErrApiRespNil)
@@ -754,11 +719,9 @@ func (t *TritonClientService) ShareCUDAMemoryUnRegister(
 }
 
 // ShareSystemMemoryRegister system share memory register.
-func (t *TritonClientService) ShareSystemMemoryRegister(
-	regionName, cpuMemRegionKey string, byteSize, cpuMemOffset uint64, timeout time.Duration,
-) (*SystemSharedMemoryRegisterResponse, error) {
+func (t *TritonClientService) ShareSystemMemoryRegister(regionName, cpuMemRegionKey string, byteSize, cpuMemOffset uint64) (*SystemSharedMemoryRegisterResponse, error) {
 	if t.grpcClient != nil {
-		ctx, cancel := context.WithTimeout(context.Background(), timeout)
+		ctx, cancel := context.WithTimeout(context.Background(), t.apiTimeout)
 		defer cancel()
 
 		// System Memory
@@ -778,7 +741,7 @@ func (t *TritonClientService) ShareSystemMemoryRegister(
 		return nil, jsonEncodeErr
 	}
 	apiResp, httpErr := t.makeHTTPPostRequestWithDoTimeout(
-		t.getServerURL()+TritonAPIForSystemMemoryRegionPrefix+regionName+"/register", reqBody, timeout)
+		t.getServerURL()+TritonAPIForSystemMemoryRegionPrefix+regionName+"/register", reqBody)
 	defer fasthttp.ReleaseResponse(apiResp)
 	if apiResp == nil {
 		return nil, t.httpErrorHandler(http.StatusInternalServerError, utils.ErrApiRespNil)
@@ -794,11 +757,9 @@ func (t *TritonClientService) ShareSystemMemoryRegister(
 }
 
 // ShareSystemMemoryUnRegister system share memory unregister.
-func (t *TritonClientService) ShareSystemMemoryUnRegister(
-	regionName string, timeout time.Duration,
-) (*SystemSharedMemoryUnregisterResponse, error) {
+func (t *TritonClientService) ShareSystemMemoryUnRegister(regionName string) (*SystemSharedMemoryUnregisterResponse, error) {
 	if t.grpcClient != nil {
-		ctx, cancel := context.WithTimeout(context.Background(), timeout)
+		ctx, cancel := context.WithTimeout(context.Background(), t.apiTimeout)
 		defer cancel()
 
 		// System Memory
@@ -807,7 +768,7 @@ func (t *TritonClientService) ShareSystemMemoryUnRegister(
 		return systemSharedMemoryUnRegisterResponse, t.grpcErrorHandler(unRegisterErr)
 	}
 	apiResp, httpErr := t.makeHTTPPostRequestWithDoTimeout(
-		t.getServerURL()+TritonAPIForSystemMemoryRegionPrefix+regionName+"/unregister", nil, timeout)
+		t.getServerURL()+TritonAPIForSystemMemoryRegionPrefix+regionName+"/unregister", nil)
 	defer fasthttp.ReleaseResponse(apiResp)
 	if apiResp == nil {
 		return nil, t.httpErrorHandler(http.StatusInternalServerError, utils.ErrApiRespNil)
@@ -823,11 +784,9 @@ func (t *TritonClientService) ShareSystemMemoryUnRegister(
 }
 
 // GetModelTracingSetting get model tracing setting.
-func (t *TritonClientService) GetModelTracingSetting(
-	modelName string, timeout time.Duration,
-) (*TraceSettingResponse, error) {
+func (t *TritonClientService) GetModelTracingSetting(modelName string) (*TraceSettingResponse, error) {
 	if t.grpcClient != nil {
-		ctx, cancel := context.WithTimeout(context.Background(), timeout)
+		ctx, cancel := context.WithTimeout(context.Background(), t.apiTimeout)
 		defer cancel()
 
 		// Tracing
@@ -836,7 +795,7 @@ func (t *TritonClientService) GetModelTracingSetting(
 		return traceSettingResponse, t.grpcErrorHandler(getTraceSettingErr)
 	}
 	apiResp, httpErr := t.makeHTTPGetRequestWithDoTimeout(
-		t.getServerURL()+TritonAPIForModelPrefix+modelName+"/trace/setting", timeout)
+		t.getServerURL() + TritonAPIForModelPrefix + modelName + "/trace/setting")
 	defer fasthttp.ReleaseResponse(apiResp)
 	if apiResp == nil {
 		return nil, t.httpErrorHandler(http.StatusInternalServerError, utils.ErrApiRespNil)
@@ -855,10 +814,10 @@ func (t *TritonClientService) GetModelTracingSetting(
 // Param: settingMap ==>
 // https://github.com/triton-inference-server/server/blob/main/docs/protocol/extension_trace.md#trace-setting-response-json-object
 func (t *TritonClientService) SetModelTracingSetting(
-	modelName string, settingMap map[string]*TraceSettingRequest_SettingValue, timeout time.Duration,
+	modelName string, settingMap map[string]*TraceSettingRequest_SettingValue,
 ) (*TraceSettingResponse, error) {
 	if t.grpcClient != nil {
-		ctx, cancel := context.WithTimeout(context.Background(), timeout)
+		ctx, cancel := context.WithTimeout(context.Background(), t.apiTimeout)
 		defer cancel()
 
 		traceSettingResponse, setTraceSettingErr := t.grpcClient.TraceSetting(
@@ -870,8 +829,7 @@ func (t *TritonClientService) SetModelTracingSetting(
 	if jsonEncodeErr != nil {
 		return nil, jsonEncodeErr
 	}
-	apiResp, httpErr := t.makeHTTPPostRequestWithDoTimeout(
-		t.getServerURL()+TritonAPIForModelPrefix+modelName+"/trace/setting", reqBody, timeout)
+	apiResp, httpErr := t.makeHTTPPostRequestWithDoTimeout(t.getServerURL()+TritonAPIForModelPrefix+modelName+"/trace/setting", reqBody)
 	defer fasthttp.ReleaseResponse(apiResp)
 	if apiResp == nil {
 		return nil, t.httpErrorHandler(http.StatusInternalServerError, utils.ErrApiRespNil)
@@ -891,6 +849,11 @@ func (t *TritonClientService) SetSecondaryServerURL(url string) {
 	t.secondaryServerURL = url
 }
 
+// SetAPIRequestTimeout Set API request timeout.
+func (t *TritonClientService) SetAPIRequestTimeout(timeout time.Duration) {
+	t.apiTimeout = timeout
+}
+
 // ShutdownTritonConnection shutdown http and grpc connection.
 func (t *TritonClientService) ShutdownTritonConnection() (disconnectionErr error) {
 	if t.grpcConn != nil {
@@ -920,6 +883,7 @@ func NewTritonClientWithOnlyGRPC(grpcConn *grpc.ClientConn) *TritonClientService
 	client := &TritonClientService{
 		grpcConn:    grpcConn,
 		grpcClient:  NewGRPCInferenceServiceClient(grpcConn),
+		apiTimeout:  DefaultHTTPClientReadTimeout,
 		JSONEncoder: json.Marshal,
 		JSONDecoder: json.Unmarshal,
 	}
@@ -927,13 +891,12 @@ func NewTritonClientWithOnlyGRPC(grpcConn *grpc.ClientConn) *TritonClientService
 }
 
 // NewTritonClientForAll init triton client with http and grpc.
-func NewTritonClientForAll(
-	httpServerURL string, httpClient *fasthttp.Client, grpcConn *grpc.ClientConn,
-) *TritonClientService {
+func NewTritonClientForAll(httpServerURL string, httpClient *fasthttp.Client, grpcConn *grpc.ClientConn) *TritonClientService {
 	client := &TritonClientService{
 		serverURL:   httpServerURL,
 		grpcConn:    grpcConn,
 		grpcClient:  NewGRPCInferenceServiceClient(grpcConn),
+		apiTimeout:  DefaultHTTPClientReadTimeout,
 		JSONEncoder: json.Marshal,
 		JSONDecoder: json.Unmarshal,
 	}
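
Migration note: with this change the request timeout is configured once on the client instead of being passed to every call. Below is a minimal usage sketch; the endpoints (`127.0.0.1:8000`/`:8001`) and the insecure gRPC credentials are illustrative assumptions, while `NewTritonClientForAll`, `SetAPIRequestTimeout`, the zero-argument `CheckServerAlive`, and the `DefaultHTTPClientReadTimeout` default all come from this diff.

```go
package main

import (
	"fmt"
	"time"

	"github.com/sunhailin-Leo/triton-service-go/v2/nvidia_inferenceserver"
	"github.com/valyala/fasthttp"
	"google.golang.org/grpc"
	"google.golang.org/grpc/credentials/insecure"
)

func main() {
	// Assumed local Triton endpoints (HTTP :8000, gRPC :8001); adjust to your deployment.
	grpcConn, dialErr := grpc.Dial("127.0.0.1:8001",
		grpc.WithTransportCredentials(insecure.NewCredentials()))
	if dialErr != nil {
		panic(dialErr)
	}
	client := nvidia_inferenceserver.NewTritonClientForAll(
		"http://127.0.0.1:8000", &fasthttp.Client{}, grpcConn)

	// v2.0.5: no per-call timeout argument; set it once on the client.
	// If never set, it defaults to DefaultHTTPClientReadTimeout and is applied
	// to both the fasthttp DoTimeout calls and the gRPC context deadlines.
	client.SetAPIRequestTimeout(5 * time.Second)

	// v2.0.4 and earlier: alive, err := client.CheckServerAlive(5 * time.Second)
	alive, aliveErr := client.CheckServerAlive()
	fmt.Println(alive, aliveErr)

	_ = client.ShutdownTritonConnection()
}
```

The same setter is exposed on the higher-level `models.ModelService` wrapper and returns `*ModelService`, so model services built on it can configure the timeout in the same chainable builder style as `SetSecondaryServerURL` and `SetJsonEncoder`.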