peridot/vendor/alexejk.io/go-xmlrpc/encode.go

234 lines
5.8 KiB
Go

package xmlrpc
import (
	"encoding/base64"
	"encoding/xml"
	"fmt"
	"io"
	"math"
	"reflect"
	"strconv"
	"time"
)
// Encoder implementations are responsible for handling encoding of XML-RPC requests to the proper wire format.
type Encoder interface {
// Encode writes the request for methodName, together with the supplied
// args, to w. A non-nil error reports that the request could not be encoded.
Encode(w io.Writer, methodName string, args interface{}) error
}
// StdEncoder is the default implementation of Encoder interface.
// It has no fields, so its zero value is ready to use.
type StdEncoder struct{}
// Encode writes an XML-RPC <methodCall> document for methodName to w.
//
// The method name is XML-escaped before it is written, so a name containing
// characters such as '<' or '&' cannot corrupt the document. When args is
// non-nil it is passed to encodeArgs, which emits the parameter section; a nil
// args produces a call without parameters.
//
// An error is returned when the arguments cannot be encoded or when writing
// to w fails.
func (e *StdEncoder) Encode(w io.Writer, methodName string, args interface{}) error {
	if _, err := io.WriteString(w, "<methodCall><methodName>"); err != nil {
		return fmt.Errorf("cannot write method call: %w", err)
	}
	if err := xml.EscapeText(w, []byte(methodName)); err != nil {
		return fmt.Errorf("cannot write method name: %w", err)
	}
	if _, err := io.WriteString(w, "</methodName>"); err != nil {
		return fmt.Errorf("cannot write method call: %w", err)
	}

	if args != nil {
		if err := e.encodeArgs(w, args); err != nil {
			return fmt.Errorf("cannot encode provided method arguments: %w", err)
		}
	}

	if _, err := io.WriteString(w, "</methodCall>"); err != nil {
		return fmt.Errorf("cannot write method call: %w", err)
	}
	return nil
}
// encodeArgs writes the parameter section of a method call to w.
//
// args must be a struct or a pointer to a struct. Every exported field is
// encoded, in declaration order, as one <param> element; unexported fields
// are skipped. The surrounding <params> element is only written when at least
// one exported field exists. A nil args (including a typed nil pointer)
// produces no output.
//
// An error is returned when args is of any other kind, when a field cannot be
// encoded, or when writing to w fails.
func (e *StdEncoder) encodeArgs(w io.Writer, args interface{}) error {
	// Allows reading both pointer and value-structs.
	elem := reflect.Indirect(reflect.ValueOf(args))
	if !elem.IsValid() {
		// Nil interface or nil pointer: there is nothing to encode.
		return nil
	}
	if elem.Kind() != reflect.Struct {
		return fmt.Errorf("arguments must be a struct or a pointer to a struct, got %s", elem.Kind())
	}

	elemType := elem.Type()
	wroteParams := false
	for i := 0; i < elem.NumField(); i++ {
		field := elem.Field(i)
		if !field.CanInterface() {
			continue
		}

		// The first exported field opens the <params> element.
		if !wroteParams {
			if _, err := io.WriteString(w, "<params>"); err != nil {
				return err
			}
			wroteParams = true
		}

		if _, err := io.WriteString(w, "<param>"); err != nil {
			return err
		}
		if err := e.encodeValue(w, field.Interface()); err != nil {
			return fmt.Errorf("cannot encode argument '%s': %w", elemType.Field(i).Name, err)
		}
		if _, err := io.WriteString(w, "</param>"); err != nil {
			return err
		}
	}

	// Only close <params> if it was opened.
	if wroteParams {
		if _, err := io.WriteString(w, "</params>"); err != nil {
			return err
		}
	}
	return nil
}
// encodeValue will encode input into the XML-RPC compatible format and write
// it to w wrapped in a <value> element.
//
// If provided value is a pointer, value of pointer will be used, unless
// pointer is nil. In that case a <nil/> value is written. An untyped nil is
// handled the same way.
//
// Dispatch is based on the reflected kind, so named types (for example
// `type ID int`) are encoded like their underlying type:
//   - bool                              -> encodeBoolean
//   - signed and unsigned integer kinds -> encodeInteger (must fit in int)
//   - float32, float64                  -> encodeDouble
//   - string                            -> encodeString
//   - []byte                            -> encodeBase64
//   - other arrays and slices           -> encodeArray
//   - time.Time                         -> encodeTime
//   - other structs                     -> encodeStruct
//
// Any other kind results in an error rather than an empty <value> element.
//
// See more: https://en.wikipedia.org/wiki/XML-RPC#Data_types
func (e *StdEncoder) encodeValue(w io.Writer, value interface{}) error {
	valueOf := reflect.ValueOf(value)

	// An untyped nil has no kind to dispatch on; encode it like a nil pointer.
	if !valueOf.IsValid() {
		_, err := io.WriteString(w, "<value><nil/></value>")
		return err
	}

	kind := valueOf.Kind()

	// Handling pointers by following them.
	if kind == reflect.Ptr {
		if valueOf.IsNil() {
			_, err := io.WriteString(w, "<value><nil/></value>")
			return err
		}
		return e.encodeValue(w, valueOf.Elem().Interface())
	}

	if _, err := io.WriteString(w, "<value>"); err != nil {
		return err
	}

	// Largest value representable by the platform's int.
	const maxInt = int(^uint(0) >> 1)

	switch kind {
	case reflect.Bool:
		// reflect accessors are used instead of type assertions so that
		// named types of this kind do not panic.
		if err := e.encodeBoolean(w, valueOf.Bool()); err != nil {
			return fmt.Errorf("cannot encode boolean value: %w", err)
		}

	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
		n := valueOf.Int()
		if int64(int(n)) != n {
			return fmt.Errorf("cannot encode integer value: %d overflows int", n)
		}
		if err := e.encodeInteger(w, int(n)); err != nil {
			return fmt.Errorf("cannot encode integer value: %w", err)
		}

	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
		u := valueOf.Uint()
		if u > uint64(maxInt) {
			return fmt.Errorf("cannot encode integer value: %d overflows int", u)
		}
		if err := e.encodeInteger(w, int(u)); err != nil {
			return fmt.Errorf("cannot encode integer value: %w", err)
		}

	case reflect.Float32, reflect.Float64:
		if err := e.encodeDouble(w, valueOf.Float()); err != nil {
			return fmt.Errorf("cannot encode double value: %w", err)
		}

	case reflect.String:
		if err := e.encodeString(w, valueOf.String()); err != nil {
			return fmt.Errorf("cannot encode string value: %w", err)
		}

	case reflect.Array, reflect.Slice:
		if e.isByteArray(value) {
			if err := e.encodeBase64(w, value.([]byte)); err != nil {
				return fmt.Errorf("cannot encode byte-array value: %w", err)
			}
		} else {
			if err := e.encodeArray(w, value); err != nil {
				return fmt.Errorf("cannot encode array value: %w", err)
			}
		}

	case reflect.Struct:
		// A type assertion identifies time.Time exactly, unlike comparing
		// the printed type name.
		if t, ok := value.(time.Time); ok {
			if err := e.encodeTime(w, t); err != nil {
				return fmt.Errorf("cannot encode time.Time value: %w", err)
			}
		} else {
			if err := e.encodeStruct(w, value); err != nil {
				return fmt.Errorf("cannot encode struct value: %w", err)
			}
		}

	default:
		return fmt.Errorf("cannot encode value of unsupported type %T", value)
	}

	if _, err := io.WriteString(w, "</value>"); err != nil {
		return err
	}
	return nil
}
// isByteArray reports whether val holds a value whose dynamic type is exactly
// []byte. Byte arrays and named byte-slice types do not match.
func (e *StdEncoder) isByteArray(val interface{}) bool {
	switch val.(type) {
	case []byte:
		return true
	default:
		return false
	}
}
// encodeInteger writes val to w as a decimal number inside an <int> element
// and returns any error reported by the writer.
func (e *StdEncoder) encodeInteger(w io.Writer, val int) error {
	element := fmt.Sprintf("<int>%d</int>", val)
	if _, err := io.WriteString(w, element); err != nil {
		return err
	}
	return nil
}
// encodeDouble writes val to w inside a <double> element and returns any
// error reported by the writer.
//
// The number is rendered in plain decimal notation (no exponent) using the
// fewest digits that still identify the float64 exactly, so values are not
// truncated to a fixed number of decimals. Whole numbers keep an explicit
// ".0" fraction. NaN and infinities are written using Go's textual form.
func (e *StdEncoder) encodeDouble(w io.Writer, val float64) error {
	text := strconv.FormatFloat(val, 'f', -1, 64)

	// With shortest-form formatting a finite whole number has no fractional
	// part; add one so the output always reads as a decimal number.
	if !math.IsInf(val, 0) && math.Trunc(val) == val {
		text += ".0"
	}

	_, err := fmt.Fprintf(w, "<double>%s</double>", text)
	return err
}
// encodeBoolean writes val to w as a <boolean> element, using 1 for true and
// 0 for false, and returns any error reported by the writer.
func (e *StdEncoder) encodeBoolean(w io.Writer, val bool) error {
	element := "<boolean>0</boolean>"
	if val {
		element = "<boolean>1</boolean>"
	}

	if _, err := fmt.Fprint(w, element); err != nil {
		return err
	}
	return nil
}
// encodeString writes val to w inside a <string> element, XML-escaping the
// text so that markup characters in val cannot break the document.
//
// An error is returned when escaping fails or when writing either tag to w
// fails.
func (e *StdEncoder) encodeString(w io.Writer, val string) error {
	if _, err := io.WriteString(w, "<string>"); err != nil {
		return fmt.Errorf("failed to write string value: %w", err)
	}
	if err := xml.EscapeText(w, []byte(val)); err != nil {
		return fmt.Errorf("failed to escape string: %w", err)
	}
	if _, err := io.WriteString(w, "</string>"); err != nil {
		return fmt.Errorf("failed to write string value: %w", err)
	}
	return nil
}
// encodeArray writes val, which must be an array or a slice, to w as an
// <array><data>…</data></array> element. Each element is encoded in index
// order through encodeValue.
//
// An error is returned when val is not an array or slice, when an element
// cannot be encoded (the failing index is included in the message), or when
// writing to w fails.
func (e *StdEncoder) encodeArray(w io.Writer, val interface{}) error {
	items := reflect.ValueOf(val)
	if kind := items.Kind(); kind != reflect.Array && kind != reflect.Slice {
		return fmt.Errorf("expected array or slice, got %T", val)
	}

	if _, err := io.WriteString(w, "<array><data>"); err != nil {
		return err
	}

	for i, n := 0, items.Len(); i < n; i++ {
		if err := e.encodeValue(w, items.Index(i).Interface()); err != nil {
			return fmt.Errorf("cannot encode array element at index %d: %w", i, err)
		}
	}

	if _, err := io.WriteString(w, "</data></array>"); err != nil {
		return err
	}
	return nil
}
// encodeStruct writes val, which must be a struct value, to w as a <struct>
// element containing one <member> per exported field, in declaration order.
// Unexported fields are skipped.
//
// The member name is the field's `xml` struct tag when that tag is non-empty
// and the Go field name otherwise; it is XML-escaped before being written.
// Field values are encoded through encodeValue.
//
// An error is returned when val is not a struct, when a field value cannot be
// encoded, or when writing to w fails.
func (e *StdEncoder) encodeStruct(w io.Writer, val interface{}) error {
	structVal := reflect.ValueOf(val)
	if structVal.Kind() != reflect.Struct {
		return fmt.Errorf("expected struct, got %T", val)
	}
	structType := structVal.Type()

	if _, err := io.WriteString(w, "<struct>"); err != nil {
		return err
	}

	for i := 0; i < structType.NumField(); i++ {
		field := structVal.Field(i)

		// Skip over unexported fields.
		if !field.CanInterface() {
			continue
		}

		fieldType := structType.Field(i)
		fieldName := fieldType.Name
		if tag := fieldType.Tag.Get("xml"); tag != "" {
			fieldName = tag
		}

		if _, err := io.WriteString(w, "<member><name>"); err != nil {
			return err
		}
		// Names come from struct tags and may contain markup characters.
		if err := xml.EscapeText(w, []byte(fieldName)); err != nil {
			return fmt.Errorf("cannot write name of struct field '%s': %w", fieldName, err)
		}
		if _, err := io.WriteString(w, "</name>"); err != nil {
			return err
		}

		if err := e.encodeValue(w, field.Interface()); err != nil {
			return fmt.Errorf("cannot encode value of struct field '%s': %w", fieldName, err)
		}

		if _, err := io.WriteString(w, "</member>"); err != nil {
			return err
		}
	}

	if _, err := io.WriteString(w, "</struct>"); err != nil {
		return err
	}
	return nil
}
// encodeBase64 writes val to w inside a <base64> element, encoded with the
// standard padded base64 alphabet, and returns any error reported by the
// writer.
func (e *StdEncoder) encodeBase64(w io.Writer, val []byte) error {
	encoded := base64.StdEncoding.EncodeToString(val)
	if _, err := fmt.Fprintf(w, "<base64>%s</base64>", encoded); err != nil {
		return err
	}
	return nil
}
// encodeTime writes val to w inside a <dateTime.iso8601> element, formatted
// with the time.RFC3339 layout, and returns any error reported by the writer.
func (e *StdEncoder) encodeTime(w io.Writer, val time.Time) error {
	stamp := val.Format(time.RFC3339)
	if _, err := fmt.Fprintf(w, "<dateTime.iso8601>%s</dateTime.iso8601>", stamp); err != nil {
		return err
	}
	return nil
}