549 lines
12 KiB
Go
549 lines
12 KiB
Go
package msgp
|
|
|
|
import (
|
|
"fmt"
|
|
"math"
|
|
)
|
|
|
|
// Extension type numbers reserved by this package for its built-in
// extensions. RegisterExtension panics if asked to register any of them.
const (
	Complex64Extension  = 3 // extension number used for complex64
	Complex128Extension = 4 // extension number used for complex128
	TimeExtension       = 5 // extension number used for time.Time
)
|
|
|
|
// our extensions live here
|
|
var extensionReg = make(map[int8]func() Extension)
|
|
|
|
// RegisterExtension registers extensions so that they
|
|
// can be initialized and returned by methods that
|
|
// decode `interface{}` values. This should only
|
|
// be called during initialization. f() should return
|
|
// a newly-initialized zero value of the extension. Keep in
|
|
// mind that extensions 3, 4, and 5 are reserved for
|
|
// complex64, complex128, and time.Time, respectively,
|
|
// and that MessagePack reserves extension types from -127 to -1.
|
|
//
|
|
// For example, if you wanted to register a user-defined struct:
|
|
//
|
|
// msgp.RegisterExtension(10, func() msgp.Extension { &MyExtension{} })
|
|
//
|
|
// RegisterExtension will panic if you call it multiple times
|
|
// with the same 'typ' argument, or if you use a reserved
|
|
// type (3, 4, or 5).
|
|
func RegisterExtension(typ int8, f func() Extension) {
|
|
switch typ {
|
|
case Complex64Extension, Complex128Extension, TimeExtension:
|
|
panic(fmt.Sprint("msgp: forbidden extension type:", typ))
|
|
}
|
|
if _, ok := extensionReg[typ]; ok {
|
|
panic(fmt.Sprint("msgp: RegisterExtension() called with typ", typ, "more than once"))
|
|
}
|
|
extensionReg[typ] = f
|
|
}
|
|
|
|
// ExtensionTypeError is returned when the extension type recorded in the
// encoded data does not match the type of the Extension being decoded into.
type ExtensionTypeError struct {
	Got  int8 // type number found in the encoded data
	Want int8 // type number reported by the destination Extension
}

// Error implements the error interface.
func (e ExtensionTypeError) Error() string {
	const format = "msgp: error decoding extension: wanted type %d; got type %d"
	return fmt.Sprintf(format, e.Want, e.Got)
}

// Resumable always reports true for an ExtensionTypeError. The readers in
// this file detect the mismatch before consuming any input.
func (e ExtensionTypeError) Resumable() bool {
	return true
}

// errExt builds the ExtensionTypeError for a wire type that differs from
// the expected one.
func errExt(got int8, wanted int8) error {
	var mismatch ExtensionTypeError
	mismatch.Want = wanted
	mismatch.Got = got
	return mismatch
}
|
|
|
|
// Extension is implemented by types that supply their own binary
// encoding and travel on the wire as MessagePack extension objects.
type Extension interface {
	// ExtensionType returns the int8 type number identifying the
	// concrete extension. Negative type numbers are reserved by the
	// MessagePack specification.
	ExtensionType() int8

	// Len returns the number of payload bytes MarshalBinaryTo writes.
	Len() int

	// MarshalBinaryTo copies the encoded payload into the supplied
	// slice, which the caller sizes to hold at least Len() bytes.
	MarshalBinaryTo([]byte) error

	// UnmarshalBinary decodes the extension payload from the supplied
	// slice. ReadExtension passes bytes peeked from its underlying
	// reader, which may not stay valid after the call returns, so
	// implementations should copy any bytes they need to retain.
	UnmarshalBinary([]byte) error
}
|
|
|
|
// RawExtension implements the Extension interface by holding the
// extension payload verbatim, without interpreting it.
type RawExtension struct {
	Data []byte
	Type int8
}

// ExtensionType implements Extension.ExtensionType and reports r.Type.
func (r *RawExtension) ExtensionType() int8 {
	return r.Type
}

// Len implements Extension.Len and reports len(r.Data).
func (r *RawExtension) Len() int {
	return len(r.Data)
}

// MarshalBinaryTo implements Extension.MarshalBinaryTo by copying r.Data
// into dst. At most len(dst) bytes are copied; it never returns an error.
func (r *RawExtension) MarshalBinaryTo(dst []byte) error {
	copy(dst, r.Data)
	return nil
}

// UnmarshalBinary implements Extension.UnmarshalBinary. It copies b into
// r.Data, reusing r.Data's backing array when its capacity is sufficient
// and allocating an exactly-sized one otherwise. Later changes to b do not
// affect r.Data unless b already shares r.Data's storage.
func (r *RawExtension) UnmarshalBinary(b []byte) error {
	n := len(b)
	if cap(r.Data) < n {
		r.Data = make([]byte, n)
	} else {
		r.Data = r.Data[:n]
	}
	copy(r.Data, b)
	return nil
}
|
|
|
|
// WriteExtension writes an extension type to the writer
|
|
func (mw *Writer) WriteExtension(e Extension) error {
|
|
l := e.Len()
|
|
var err error
|
|
switch l {
|
|
case 0:
|
|
o, err := mw.require(3)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
mw.buf[o] = mext8
|
|
mw.buf[o+1] = 0
|
|
mw.buf[o+2] = byte(e.ExtensionType())
|
|
case 1:
|
|
o, err := mw.require(2)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
mw.buf[o] = mfixext1
|
|
mw.buf[o+1] = byte(e.ExtensionType())
|
|
case 2:
|
|
o, err := mw.require(2)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
mw.buf[o] = mfixext2
|
|
mw.buf[o+1] = byte(e.ExtensionType())
|
|
case 4:
|
|
o, err := mw.require(2)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
mw.buf[o] = mfixext4
|
|
mw.buf[o+1] = byte(e.ExtensionType())
|
|
case 8:
|
|
o, err := mw.require(2)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
mw.buf[o] = mfixext8
|
|
mw.buf[o+1] = byte(e.ExtensionType())
|
|
case 16:
|
|
o, err := mw.require(2)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
mw.buf[o] = mfixext16
|
|
mw.buf[o+1] = byte(e.ExtensionType())
|
|
default:
|
|
switch {
|
|
case l < math.MaxUint8:
|
|
o, err := mw.require(3)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
mw.buf[o] = mext8
|
|
mw.buf[o+1] = byte(uint8(l))
|
|
mw.buf[o+2] = byte(e.ExtensionType())
|
|
case l < math.MaxUint16:
|
|
o, err := mw.require(4)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
mw.buf[o] = mext16
|
|
big.PutUint16(mw.buf[o+1:], uint16(l))
|
|
mw.buf[o+3] = byte(e.ExtensionType())
|
|
default:
|
|
o, err := mw.require(6)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
mw.buf[o] = mext32
|
|
big.PutUint32(mw.buf[o+1:], uint32(l))
|
|
mw.buf[o+5] = byte(e.ExtensionType())
|
|
}
|
|
}
|
|
// we can only write directly to the
|
|
// buffer if we're sure that it
|
|
// fits the object
|
|
if l <= mw.bufsize() {
|
|
o, err := mw.require(l)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
return e.MarshalBinaryTo(mw.buf[o:])
|
|
}
|
|
// here we create a new buffer
|
|
// just large enough for the body
|
|
// and save it as the write buffer
|
|
err = mw.flush()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
buf := make([]byte, l)
|
|
err = e.MarshalBinaryTo(buf)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
mw.buf = buf
|
|
mw.wloc = l
|
|
return nil
|
|
}
|
|
|
|
// peek at the extension type, assuming the next
|
|
// kind to be read is Extension
|
|
func (m *Reader) peekExtensionType() (int8, error) {
|
|
p, err := m.R.Peek(2)
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
spec := sizes[p[0]]
|
|
if spec.typ != ExtensionType {
|
|
return 0, badPrefix(ExtensionType, p[0])
|
|
}
|
|
if spec.extra == constsize {
|
|
return int8(p[1]), nil
|
|
}
|
|
size := spec.size
|
|
p, err = m.R.Peek(int(size))
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
return int8(p[size-1]), nil
|
|
}
|
|
|
|
// peekExtension peeks at the extension encoding type
|
|
// (must guarantee at least 1 byte in 'b')
|
|
func peekExtension(b []byte) (int8, error) {
|
|
spec := sizes[b[0]]
|
|
size := spec.size
|
|
if spec.typ != ExtensionType {
|
|
return 0, badPrefix(ExtensionType, b[0])
|
|
}
|
|
if len(b) < int(size) {
|
|
return 0, ErrShortBytes
|
|
}
|
|
// for fixed extensions,
|
|
// the type information is in
|
|
// the second byte
|
|
if spec.extra == constsize {
|
|
return int8(b[1]), nil
|
|
}
|
|
// otherwise, it's in the last
|
|
// part of the prefix
|
|
return int8(b[size-1]), nil
|
|
}
|
|
|
|
// ReadExtension reads the next object from the reader
|
|
// as an extension. ReadExtension will fail if the next
|
|
// object in the stream is not an extension, or if
|
|
// e.Type() is not the same as the wire type.
|
|
func (m *Reader) ReadExtension(e Extension) (err error) {
|
|
var p []byte
|
|
p, err = m.R.Peek(2)
|
|
if err != nil {
|
|
return
|
|
}
|
|
lead := p[0]
|
|
var read int
|
|
var off int
|
|
switch lead {
|
|
case mfixext1:
|
|
if int8(p[1]) != e.ExtensionType() {
|
|
err = errExt(int8(p[1]), e.ExtensionType())
|
|
return
|
|
}
|
|
p, err = m.R.Peek(3)
|
|
if err != nil {
|
|
return
|
|
}
|
|
err = e.UnmarshalBinary(p[2:])
|
|
if err == nil {
|
|
_, err = m.R.Skip(3)
|
|
}
|
|
return
|
|
|
|
case mfixext2:
|
|
if int8(p[1]) != e.ExtensionType() {
|
|
err = errExt(int8(p[1]), e.ExtensionType())
|
|
return
|
|
}
|
|
p, err = m.R.Peek(4)
|
|
if err != nil {
|
|
return
|
|
}
|
|
err = e.UnmarshalBinary(p[2:])
|
|
if err == nil {
|
|
_, err = m.R.Skip(4)
|
|
}
|
|
return
|
|
|
|
case mfixext4:
|
|
if int8(p[1]) != e.ExtensionType() {
|
|
err = errExt(int8(p[1]), e.ExtensionType())
|
|
return
|
|
}
|
|
p, err = m.R.Peek(6)
|
|
if err != nil {
|
|
return
|
|
}
|
|
err = e.UnmarshalBinary(p[2:])
|
|
if err == nil {
|
|
_, err = m.R.Skip(6)
|
|
}
|
|
return
|
|
|
|
case mfixext8:
|
|
if int8(p[1]) != e.ExtensionType() {
|
|
err = errExt(int8(p[1]), e.ExtensionType())
|
|
return
|
|
}
|
|
p, err = m.R.Peek(10)
|
|
if err != nil {
|
|
return
|
|
}
|
|
err = e.UnmarshalBinary(p[2:])
|
|
if err == nil {
|
|
_, err = m.R.Skip(10)
|
|
}
|
|
return
|
|
|
|
case mfixext16:
|
|
if int8(p[1]) != e.ExtensionType() {
|
|
err = errExt(int8(p[1]), e.ExtensionType())
|
|
return
|
|
}
|
|
p, err = m.R.Peek(18)
|
|
if err != nil {
|
|
return
|
|
}
|
|
err = e.UnmarshalBinary(p[2:])
|
|
if err == nil {
|
|
_, err = m.R.Skip(18)
|
|
}
|
|
return
|
|
|
|
case mext8:
|
|
p, err = m.R.Peek(3)
|
|
if err != nil {
|
|
return
|
|
}
|
|
if int8(p[2]) != e.ExtensionType() {
|
|
err = errExt(int8(p[2]), e.ExtensionType())
|
|
return
|
|
}
|
|
read = int(uint8(p[1]))
|
|
off = 3
|
|
|
|
case mext16:
|
|
p, err = m.R.Peek(4)
|
|
if err != nil {
|
|
return
|
|
}
|
|
if int8(p[3]) != e.ExtensionType() {
|
|
err = errExt(int8(p[3]), e.ExtensionType())
|
|
return
|
|
}
|
|
read = int(big.Uint16(p[1:]))
|
|
off = 4
|
|
|
|
case mext32:
|
|
p, err = m.R.Peek(6)
|
|
if err != nil {
|
|
return
|
|
}
|
|
if int8(p[5]) != e.ExtensionType() {
|
|
err = errExt(int8(p[5]), e.ExtensionType())
|
|
return
|
|
}
|
|
read = int(big.Uint32(p[1:]))
|
|
off = 6
|
|
|
|
default:
|
|
err = badPrefix(ExtensionType, lead)
|
|
return
|
|
}
|
|
|
|
p, err = m.R.Peek(read + off)
|
|
if err != nil {
|
|
return
|
|
}
|
|
err = e.UnmarshalBinary(p[off:])
|
|
if err == nil {
|
|
_, err = m.R.Skip(read + off)
|
|
}
|
|
return
|
|
}
|
|
|
|
// AppendExtension appends a MessagePack extension to the provided slice
|
|
func AppendExtension(b []byte, e Extension) ([]byte, error) {
|
|
l := e.Len()
|
|
var o []byte
|
|
var n int
|
|
switch l {
|
|
case 0:
|
|
o, n = ensure(b, 3)
|
|
o[n] = mext8
|
|
o[n+1] = 0
|
|
o[n+2] = byte(e.ExtensionType())
|
|
return o[:n+3], nil
|
|
case 1:
|
|
o, n = ensure(b, 3)
|
|
o[n] = mfixext1
|
|
o[n+1] = byte(e.ExtensionType())
|
|
n += 2
|
|
case 2:
|
|
o, n = ensure(b, 4)
|
|
o[n] = mfixext2
|
|
o[n+1] = byte(e.ExtensionType())
|
|
n += 2
|
|
case 4:
|
|
o, n = ensure(b, 6)
|
|
o[n] = mfixext4
|
|
o[n+1] = byte(e.ExtensionType())
|
|
n += 2
|
|
case 8:
|
|
o, n = ensure(b, 10)
|
|
o[n] = mfixext8
|
|
o[n+1] = byte(e.ExtensionType())
|
|
n += 2
|
|
case 16:
|
|
o, n = ensure(b, 18)
|
|
o[n] = mfixext16
|
|
o[n+1] = byte(e.ExtensionType())
|
|
n += 2
|
|
default:
|
|
switch {
|
|
case l < math.MaxUint8:
|
|
o, n = ensure(b, l+3)
|
|
o[n] = mext8
|
|
o[n+1] = byte(uint8(l))
|
|
o[n+2] = byte(e.ExtensionType())
|
|
n += 3
|
|
case l < math.MaxUint16:
|
|
o, n = ensure(b, l+4)
|
|
o[n] = mext16
|
|
big.PutUint16(o[n+1:], uint16(l))
|
|
o[n+3] = byte(e.ExtensionType())
|
|
n += 4
|
|
default:
|
|
o, n = ensure(b, l+6)
|
|
o[n] = mext32
|
|
big.PutUint32(o[n+1:], uint32(l))
|
|
o[n+5] = byte(e.ExtensionType())
|
|
n += 6
|
|
}
|
|
}
|
|
return o, e.MarshalBinaryTo(o[n:])
|
|
}
|
|
|
|
// ReadExtensionBytes reads an extension from 'b' into 'e'
|
|
// and returns any remaining bytes.
|
|
// Possible errors:
|
|
// - ErrShortBytes ('b' not long enough)
|
|
// - ExtensionTypeErorr{} (wire type not the same as e.Type())
|
|
// - TypeErorr{} (next object not an extension)
|
|
// - InvalidPrefixError
|
|
// - An umarshal error returned from e.UnmarshalBinary
|
|
func ReadExtensionBytes(b []byte, e Extension) ([]byte, error) {
|
|
l := len(b)
|
|
if l < 3 {
|
|
return b, ErrShortBytes
|
|
}
|
|
lead := b[0]
|
|
var (
|
|
sz int // size of 'data'
|
|
off int // offset of 'data'
|
|
typ int8
|
|
)
|
|
switch lead {
|
|
case mfixext1:
|
|
typ = int8(b[1])
|
|
sz = 1
|
|
off = 2
|
|
case mfixext2:
|
|
typ = int8(b[1])
|
|
sz = 2
|
|
off = 2
|
|
case mfixext4:
|
|
typ = int8(b[1])
|
|
sz = 4
|
|
off = 2
|
|
case mfixext8:
|
|
typ = int8(b[1])
|
|
sz = 8
|
|
off = 2
|
|
case mfixext16:
|
|
typ = int8(b[1])
|
|
sz = 16
|
|
off = 2
|
|
case mext8:
|
|
sz = int(uint8(b[1]))
|
|
typ = int8(b[2])
|
|
off = 3
|
|
if sz == 0 {
|
|
return b[3:], e.UnmarshalBinary(b[3:3])
|
|
}
|
|
case mext16:
|
|
if l < 4 {
|
|
return b, ErrShortBytes
|
|
}
|
|
sz = int(big.Uint16(b[1:]))
|
|
typ = int8(b[3])
|
|
off = 4
|
|
case mext32:
|
|
if l < 6 {
|
|
return b, ErrShortBytes
|
|
}
|
|
sz = int(big.Uint32(b[1:]))
|
|
typ = int8(b[5])
|
|
off = 6
|
|
default:
|
|
return b, badPrefix(ExtensionType, lead)
|
|
}
|
|
|
|
if typ != e.ExtensionType() {
|
|
return b, errExt(typ, e.ExtensionType())
|
|
}
|
|
|
|
// the data of the extension starts
|
|
// at 'off' and is 'sz' bytes long
|
|
if len(b[off:]) < sz {
|
|
return b, ErrShortBytes
|
|
}
|
|
tot := off + sz
|
|
return b[tot:], e.UnmarshalBinary(b[off:tot])
|
|
}
|