From bc595dd02410fc3021c4d8a0ca412ec3396bf9bc Mon Sep 17 00:00:00 2001 From: Nikita Tokarchuk Date: Wed, 26 Dec 2018 23:14:48 +0100 Subject: [PATCH] Export getId into function --- mongox/base/getid.go | 45 ++++++++++++++++++++++++++++++++++++++++ mongox/common/saveone.go | 18 +--------------- 2 files changed, 46 insertions(+), 17 deletions(-) create mode 100644 mongox/base/getid.go diff --git a/mongox/base/getid.go b/mongox/base/getid.go new file mode 100644 index 0000000..2a9dff1 --- /dev/null +++ b/mongox/base/getid.go @@ -0,0 +1,45 @@ +package base + +import ( + "github.com/mainnika/mongox-go-driver/mongox" + "github.com/mainnika/mongox-go-driver/mongox/errors" + "github.com/mongodb/mongo-go-driver/bson/primitive" +) + +// GetID returns source document id +func GetID(source interface{}) (id interface{}) { + + switch doc := source.(type) { + case mongox.BaseObjectID: + return getObjectIdOrGenerate(doc) + case mongox.BaseString: + return getStringIdOrPanic(doc) + default: + panic(errors.Malformedf("source contains malformed document, %v", source)) + } + + return +} + +func getObjectIdOrGenerate(source mongox.BaseObjectID) (id primitive.ObjectID) { + + id = source.GetID() + if id != primitive.NilObjectID { + return id + } + + id = primitive.NewObjectID() + source.SetID(id) + + return +} + +func getStringIdOrPanic(source mongox.BaseString) (id string) { + + id = source.GetID() + if id != "" { + return id + } + + panic(errors.Malformedf("victim contains malformed document, %v", source)) +} diff --git a/mongox/common/saveone.go b/mongox/common/saveone.go index 2200579..5af92be 100644 --- a/mongox/common/saveone.go +++ b/mongox/common/saveone.go @@ -13,27 +13,11 @@ func SaveOne(db *mongox.Database, source interface{}) error { collection := db.GetCollectionOf(source) opts := &options.FindOneAndReplaceOptions{} + id := base.GetID(source) opts.SetUpsert(true) opts.SetReturnDocument(options.After) - var id interface{} - - switch doc := source.(type) { - case mongox.BaseObjectID: - id = doc.GetID() - if id == primitive.NilObjectID { - id = primitive.NewObjectID() - } - case mongox.BaseString: - id = doc.GetID() - if id == "" { - panic(errors.Malformedf("source contains malformed document, %v", source)) - } - default: - panic(errors.Malformedf("source contains malformed document, %v", source)) - } - result := collection.FindOneAndReplace(db.Context(), bson.M{"_id": id}, source, opts) if result.Err() != nil { return errors.NotFoundErrorf("%s", result.Err())