推荐使用PAI-EAS提供的官方SDK进行服务调用,从而有效减少编写调用逻辑的时间并提高调用稳定性。本文介绍官方Golang SDK接口详情,并以常见类型的输入输出为例,提供了使用Golang SDK进行服务调用的完整程序示例。

背景信息

使用Golang SDK进行服务调用时,由于在编译代码时,Golang的包管理工具会自动从Github上将Golang SDK的代码下载到本地,因此您无需提前安装Golang SDK。如果您需要自定义部分调用逻辑,可以先下载Golang SDK代码,再对其进行修改。

接口列表

接口 描述
PredictClient NewPredictClient(endpoint string, serviceName string) *PredictClient
  • 功能:PredictClient类构造函数。
  • 参数:
    • endpoint:必填,表示服务端的Endpoint地址。对于普通服务,将其设置为默认网关Endpoint。
    • serviceName:必填,表示服务名称。
  • 返回值:创建的PredictClient对象。
SetEndpoint(endpointName string)
  • 功能:设置服务的Endpoint。
  • 参数:endpointName 表示服务端的Endpoint地址。对于普通服务,将其设置为默认网关Endpoint。
SetServiceName(serviceName string)
  • 功能:设置请求的服务名称。
  • 参数:serviceName表示请求的服务名称。
SetEndpointType(endpointType string)
  • 功能:设置服务端的网关类型。
  • 参数:endpointType表示网关类型。系统支持以下网关类型:
    • "DEFAULT":默认网关。如果不指定网关类型,默认为该类型。
    • "DIRECT":使用高速直连通道访问服务。
SetToken(token string)
  • 功能:设置服务访问的Token。
  • 参数:token表示访问服务时使用的鉴权Token。
SetHttpTransport(transport *http.Transport)
  • 功能:设置HTTP客户端的Transport属性。
  • 参数:transport表示发送HTTP请求时使用的Transport对象。
SetRetryCount(max_retry_count int)
  • 功能:设置请求失败重试次数。
  • 参数:max_retry_count表示请求失败后重连的次数,默认为5。
    说明 对于服务端进程异常、服务器异常或网关长连接断开等情况导致的个别请求失败,均需要客户端重新发送请求。因此,请勿将该参数设置为0。
SetTimeout(timeout int)
  • 功能:设置请求的超时时间。
  • 参数:timeout表示请求的超时时间,单位为ms,默认值为5000。
Init() 对PredictClient对象进行初始化。在上述设置参数的接口执行完成后,需要调用Init()接口才能生效。
Predict(request Request) Response
  • 功能:向在线预测服务提交一个预测请求。
  • 参数:Request对象是interface(StringRequest, TFRequest,TorchRequest)
  • 返回值:Response对象是interface(StringResponse, TFResponse,TorchResponse)
StringPredict(request string) string
  • 功能:向在线预测服务提交一个预测请求。
  • 参数:request对象表示待发送的请求字符串。
  • 返回值:STRING类型的服务响应。
TorchPredict(request TorchRequest) TorchResponse
  • 功能:向在线预测服务提交一个PyTorch预测请求。
  • 参数:request表示TorchRequest类的对象。
  • 返回值:对应的TorchResponse。
TFPredict(request TFRequest) TFResponse
  • 功能:向在线预测服务提交一个预测请求。
  • 参数:request表示TFRequest类的对象。
  • 返回值:对应的TFResponse。
TFRequest TFRequest(signatureName string)
  • 功能:TFRequest类的构建函数。
  • 参数:signatureName表示请求模型的Signature Name。
AddFeed(?)(inputName string, shape []int64{}, content []?)
  • 功能:请求TensorFlow的在线预测服务模型时,设置需要输入的Tensor。
  • 参数:
    • inputName:表示输入Tensor的别名。
    • shape:表示输入Tensor的TensorShape。
    • content:表示输入的Tensor的内容,通过一维数组展开表示。支持的类型包括INT32、INT64、FLOAT32、FLOAT64、STRING及BOOL,该接口名称与具体类型相关,例如AddFeedInt32()。如果需要其它数据类型,则可以参考代码自行通过PB格式构造。
AddFetch(outputName string)
  • 功能:请求TensorFlow的在线预测服务模型时,设置需要输出Tensor的别名。
  • 参数:outputName表示待获取的输出Tensor的别名。

    对于SavedModel模型,该参数可选。如果未设置,则输出所有的outputs。

    对于Frozen Model,该参数必选。

TFResponse GetTensorShape(outputName string) []int64
  • 功能:获得指定别名。的输出Tensor的TensorShape。
  • 参数:outputName表示待获取输出Shape的Tensor别名。
  • 返回值:返回的Tensor Shape,各个维度以数组形式表示。
Get(?)Val(outputName string) [](?)
  • 功能:获取输出Tensor的数据向量,输出结果以一维数组的形式保存。您可以配套使用GetTensorShape()接口,获取对应Tensor的Shape,将其还原成所需的多维Tensor。支持的类型包括FLOAT、DOUBLE、INT、INT64、STRING及BOOL,接口名称与具体类型相关,例如GetFloatVal()
  • 参数:outputName表示待获取输出数据的Tensor别名。
  • 返回值:输出Tensor的数据展开成的一维数组。
TorchRequest TorchRequest() TFRequest类的构建函数。
AddFeed(?)(index int, shape []int64{}, content []?)
  • 功能:请求PyTorch的在线预测服务模型时,设置需要输入的Tensor。
  • 参数:
    • index:表示待输入的Tensor下标。
    • shape:表示输入Tensor的TensorShape。
    • content:表示输入Tensor的内容,通过一维数组展开表示。支持的类型包括INT32、INT64、FLOAT32及FLOAT64,该接口名称与具体类型相关,例如AddFeedInt32()。如果需要其它数据类型,则可以参考代码自行通过PB格式构造。
AddFetch(outputIndex int)
  • 功能:请求PyTorch的在线预测服务模型时,设置需要输出的Tensor的Index。该接口为可选,如果您没有调用该接口设置输出Tensor的Index,则输出所有的outputs。
  • 参数:outputIndex表示输出Tensor的Index。
TorchResponse GetTensorShape(outputIndex int) []int64
  • 功能:获得指定下标的输出Tensor的TensorShape。
  • 参数:outputName表示待获取输出Shape的Tensor别名。
  • 返回值:返回的Tensor Shape,各个维度以数组形式表示。
Get(?)Val(outputIndex int) [](?)
  • 功能:获取输出Tensor的数据向量,输出结果以一维数组的形式保存。您可以配套使用GetTensorShape()接口获取对应Tensor的Shape,将其还原成所需的多维Tensor。支持的类型包括FLOAT、DOUBLE、INT及INT64,接口名称与具体类型相关,例如GetFloatVal()
  • 参数:outputIndex表示待获取输出数据Tensor的下标。
  • 返回值:输出Tensor的数据展开成的一维数组。

程序示例

  • 字符串输入输出示例
    对于使用自定义Processor部署服务的用户而言,通常采用字符串进行服务调用(例如,PMML模型服务的调用),具体的Demo程序如下。
    package main
    
    import (
            "fmt"
            "github.com/pai-eas/eas-golang-sdk/eas"
    )
    
    func main() {
        client := eas.NewPredictClient("182848887922****.cn-shanghai.pai-eas.aliyuncs.com", "scorecard_pmml_example")
        client.SetToken("YWFlMDYyZDNmNTc3M2I3MzMwYmY0MmYwM2Y2MTYxMTY4NzBkNzdj****")
        client.Init()
        req := "[{\"fea1\": 1, \"fea2\": 2}]"
        for i := 0; i < 100; i++ {
            resp, err := client.StringPredict(req)
            if err != nil {
                fmt.Printf("failed to predict: %v\n", err.Error())
            } else {
                fmt.Printf("%v\n", resp)
            }
        }
    }
  • TensorFlow输入输出示例
    使用TensorFlow的用户,需要将TFRequest和TFResponse分别作为输入和输出数据格式,具体Demo示例如下。
    package main
    
    import (
            "fmt"
            "github.com/pai-eas/eas-golang-sdk/eas"
    )
    
    func main() {
        client := eas.NewPredictClient("182848887922****.cn-shanghai.pai-eas.aliyuncs.com", "mnist_saved_model_example")
        client.SetToken("YTg2ZjE0ZjM4ZmE3OTc0NzYxZDMyNmYzMTJjZTQ1YmU0N2FjMTAy****")
        client.Init()
    
        tfreq := eas.TFRequest{}
        tfreq.SetSignatureName("predict_images")
        tfreq.AddFeedFloat32("images", []int64{1, 784}, make([]float32, 784))
    
        for i := 0; i < 100; i++ {
            resp, err := client.TFPredict(tfreq)
            if err != nil {
                fmt.Printf("failed to predict: %v", err)
            } else {
                fmt.Printf("%v\n", resp)
            }
        }
    }
  • PyTorch输入输出示例
    使用PyTorch的用户,需要将TorchRequest和TorchResponse分别作为输入和输出数据格式,具体Demo示例如下。
    package main
    
    import (
            "fmt"
            "github.com/pai-eas/eas-golang-sdk/eas"
    )
    
    func main() {
        client := eas.NewPredictClient("182848887922****.cn-shanghai.pai-eas.aliyuncs.com", "pytorch_resnet_example")
        client.SetTimeout(500)
        client.SetToken("ZjdjZDg1NWVlMWI2NTU5YzJiMmY5ZmE5OTBmYzZkMjI0YjlmYWVl****")
        client.Init()
        req := eas.TorchRequest{}
        req.AddFeedFloat32(0, []int64{1, 3, 224, 224}, make([]float32, 150528))
        req.AddFetch(0)
        for i := 0; i < 10; i++ {
            resp, err := client.TorchPredict(req)
            if err != nil {
                fmt.Printf("failed to predict: %v", err)
            } else {
                fmt.Println(resp.GetTensorShape(0), resp.GetFloatVal(0))
            }
        }
    }
  • 通过VPC网络直连方式调用服务的示例
    通过网络直连方式,您只能访问部署在PAI-EAS专属资源组的服务,且需要为该资源组与用户指定的vSwitch连通网络后才能使用。关于如何购买PAI-EAS专属资源组和连通网络,请参见专属资源组VPC高速直连。该调用方式与普通调用方式相比,仅需增加一行代码client.SetEndpointType(eas.EndpointTypeDirect)即可,特别适合大流量高并发的服务,具体示例如下。
    package main
    
    import (
            "fmt"
            "github.com/pai-eas/eas-golang-sdk/eas"
    )
    
    func main() {
        client := eas.NewPredictClient("pai-eas-vpc.cn-shanghai.aliyuncs.com", "scorecard_pmml_example")
        client.SetToken("YWFlMDYyZDNmNTc3M2I3MzMwYmY0MmYwM2Y2MTYxMTY4NzBkNzdj****")
        client.SetEndpointType(eas.EndpointTypeDirect)
        client.Init()
        req := "[{\"fea1\": 1, \"fea2\": 2}]"
        for i := 0; i < 100; i++ {
            resp, err := client.StringPredict(req)
            if err != nil {
                fmt.Printf("failed to predict: %v\n", err.Error())
            } else {
                fmt.Printf("%v\n", resp)
            }
        }
    }
  • 客户端连接参数设置的示例
    您可以通过http.Transport属性设置请求客户端的连接参数,示例代码如下。
    package main
    
    import (
            "fmt"
            "github.com/pai-eas/eas-golang-sdk/eas"
    )
    
    func main() {
        client := eas.NewPredictClient("pai-eas-vpc.cn-shanghai.aliyuncs.com", "network_test")
        client.SetToken("MDAwZDQ3NjE3OThhOTI4ODFmMjJiYzE0MDk1NWRkOGI1MmVhMGI0****")
        client.SetEndpointType(eas.EndpointTypeDirect)
        client.SetHttpTransport(&http.Transport{
            MaxConnsPerHost:       300,
            TLSHandshakeTimeout:   100 * time.Millisecond,
            ResponseHeaderTimeout: 200 * time.Millisecond,
            ExpectContinueTimeout: 200 * time.Millisecond,
        })
    }