store reflect.Type when registering Record

This commit is contained in:
Yusef Napora 2020-01-17 10:48:37 -05:00
parent 972454490a
commit 7ee4611788

View File

@ -10,7 +10,7 @@ var (
// PayloadType does not match any registered Record types. // PayloadType does not match any registered Record types.
ErrPayloadTypeNotRegistered = errors.New("payload type is not registered") ErrPayloadTypeNotRegistered = errors.New("payload type is not registered")
payloadTypeRegistry = make(map[string]Record) payloadTypeRegistry = make(map[string]reflect.Type)
) )
// Record represents a data type that can be used as the payload of an Envelope. // Record represents a data type that can be used as the payload of an Envelope.
@ -52,7 +52,7 @@ type Record interface {
// type HelloRecord struct { } // etc.. // type HelloRecord struct { } // etc..
// //
func RegisterPayloadType(payloadType []byte, prototype Record) { func RegisterPayloadType(payloadType []byte, prototype Record) {
payloadTypeRegistry[string(payloadType)] = prototype payloadTypeRegistry[string(payloadType)] = getValueType(prototype)
} }
func unmarshalRecordPayload(payloadType []byte, payloadBytes []byte) (Record, error) { func unmarshalRecordPayload(payloadType []byte, payloadBytes []byte) (Record, error) {
@ -68,12 +68,11 @@ func unmarshalRecordPayload(payloadType []byte, payloadBytes []byte) (Record, er
} }
func blankRecordForPayloadType(payloadType []byte) (Record, error) { func blankRecordForPayloadType(payloadType []byte) (Record, error) {
prototype, ok := payloadTypeRegistry[string(payloadType)] valueType, ok := payloadTypeRegistry[string(payloadType)]
if !ok { if !ok {
return nil, ErrPayloadTypeNotRegistered return nil, ErrPayloadTypeNotRegistered
} }
valueType := getValueType(prototype)
val := reflect.New(valueType) val := reflect.New(valueType)
asRecord := val.Interface().(Record) asRecord := val.Interface().(Record)
return asRecord, nil return asRecord, nil
@ -82,8 +81,7 @@ func blankRecordForPayloadType(payloadType []byte) (Record, error) {
func payloadTypeForRecord(rec Record) ([]byte, bool) { func payloadTypeForRecord(rec Record) ([]byte, bool) {
valueType := getValueType(rec) valueType := getValueType(rec)
for k, v := range payloadTypeRegistry { for k, t := range payloadTypeRegistry {
t := getValueType(v)
if t.AssignableTo(valueType) { if t.AssignableTo(valueType) {
return []byte(k), true return []byte(k), true
} }