diff --git a/mongox/base/getid.go b/mongox/base/getid.go index 3515ca6..1e31767 100644 --- a/mongox/base/getid.go +++ b/mongox/base/getid.go @@ -18,6 +18,9 @@ func GetID(source interface{}) (id interface{}) { return getStringIDOrPanic(doc) case mongox.ObjectBased: return getObjectOrPanic(doc) + case mongox.InterfaceBased: + return getInterfaceOrPanic(doc) + default: panic(fmt.Errorf("source contains malformed document, %v", source)) } @@ -55,3 +58,13 @@ func getObjectOrPanic(source mongox.ObjectBased) (id primitive.D) { panic(fmt.Errorf("source contains malformed document, %v", source)) } + +func getInterfaceOrPanic(source mongox.InterfaceBased) (id interface{}) { + + id = source.GetID() + if id != nil { + return id + } + + panic(fmt.Errorf("source contains malformed document, %v", source)) +} diff --git a/mongox/base/getid_test.go b/mongox/base/getid_test.go index 2cb8ee9..bf62fad 100644 --- a/mongox/base/getid_test.go +++ b/mongox/base/getid_test.go @@ -6,6 +6,18 @@ import ( "go.mongodb.org/mongo-driver/bson/primitive" ) +type DocWithCustomInterface struct { + ID int `bson:"_id" json:"_id" collection:"4"` +} + +func (d *DocWithCustomInterface) GetID() interface{} { + return d.ID +} + +func (d *DocWithCustomInterface) SetID(id interface{}) { + panic("not implemented") +} + func TestGetID(t *testing.T) { type DocWithObjectID struct { @@ -21,4 +33,5 @@ func TestGetID(t *testing.T) { GetID(&DocWithObjectID{ObjectID: ObjectID([12]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2})}) GetID(&DocWithObject{Object: Object(primitive.D{{"1", "2"}})}) GetID(&DocWithString{String: String("foobar")}) + GetID(&DocWithCustomInterface{ID: 420}) } diff --git a/mongox/mongox.go b/mongox/mongox.go index de9452a..422002d 100644 --- a/mongox/mongox.go +++ b/mongox/mongox.go @@ -97,3 +97,8 @@ type ObjectBased interface { GetID() primitive.D SetID(id primitive.D) } + +type InterfaceBased interface { + GetID() interface{} + SetID(id interface{}) +}