diff --git a/go.mod b/go.mod index 4502cf4..63ef406 100644 --- a/go.mod +++ b/go.mod @@ -4,7 +4,7 @@ go 1.13 require ( github.com/modern-go/reflect2 v1.0.2 - github.com/stretchr/testify v1.7.0 + github.com/stretchr/testify v1.8.4 github.com/valyala/bytebufferpool v1.0.0 - go.mongodb.org/mongo-driver v1.8.2 + go.mongodb.org/mongo-driver v1.11.6 ) diff --git a/go.sum b/go.sum index 4d69013..1543141 100644 --- a/go.sum +++ b/go.sum @@ -1,8 +1,6 @@ github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/go-stack/stack v1.8.0 h1:5SgMzNM5HxrEjV0ww2lTmX6E2Izsfxas4+YHWRs3Lsk= -github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY= github.com/golang/snappy v0.0.1 h1:Qgr9rKW7uDUkrbSmQeiDsGa8SjGyCOGtuasMWwvp2P4= github.com/golang/snappy v0.0.1/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= github.com/google/go-cmp v0.5.2 h1:X2ev0eStA3AbceY54o37/0PQ/UWqKEiiO2dKL5OPaFM= @@ -16,50 +14,53 @@ github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M= github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk= +github.com/montanaflynn/stats v0.0.0-20171201202039-1bf9dbcd8cbe h1:iruDEfMl2E6fbMZ9s0scYfZQ84/6SPL6zC8ACM2oIL0= github.com/montanaflynn/stats v0.0.0-20171201202039-1bf9dbcd8cbe/go.mod h1:wL8QJuTMNUDYhXwkmfOly8iTdp5TEcJFWZD2D7SIkUc= github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= -github.com/stretchr/objx v0.1.0 h1:4G4v2dO3VZwixGIRoQ5Lfboy6nUhCyYzaqnIAPPhYs4= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= +github.com/stretchr/objx v0.5.0 h1:1zr/of2m5FGMsad5YfcqgdqdWrIhu+EBEJRhR1U7z/c= +github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= -github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY= -github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= +github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= +github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= github.com/tidwall/pretty v1.0.0 h1:HsD+QiTn7sK6flMKIvNmpqz1qrpP3Ps6jOKIKMooyg4= github.com/tidwall/pretty v1.0.0/go.mod h1:XNkn88O1ChpSDQmQeStsy+sBenx6DDtFZJxhVysOjyk= github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw= github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc= github.com/xdg-go/pbkdf2 v1.0.0 h1:Su7DPu48wXMwC3bs7MCNG+z4FhcyEuz5dlvchbq0B0c= github.com/xdg-go/pbkdf2 v1.0.0/go.mod h1:jrpuAogTd400dnrH08LKmI/xc1MbPOebTwRqcT5RDeI= -github.com/xdg-go/scram v1.0.2 h1:akYIkZ28e6A96dkWNJQu3nmCzH3YfwMPQExUYDaRv7w= -github.com/xdg-go/scram v1.0.2/go.mod h1:1WAq6h33pAW+iRreB34OORO2Nf7qel3VV3fjBj+hCSs= -github.com/xdg-go/stringprep v1.0.2 h1:6iq84/ryjjeRmMJwxutI51F2GIPlP5BfTvXHeYjyhBc= -github.com/xdg-go/stringprep v1.0.2/go.mod h1:8F9zXuvzgwmyT5DUm4GUfZGDdT3W+LCvS6+da4O5kxM= +github.com/xdg-go/scram v1.1.1 h1:VOMT+81stJgXW3CpHyqHN3AXDYIMsx56mEFrB37Mb/E= +github.com/xdg-go/scram v1.1.1/go.mod h1:RaEWvsqvNKKvBPvcKeFjrG2cJqOkHTiyTpzz23ni57g= +github.com/xdg-go/stringprep v1.0.3 h1:kdwGpVNwPFtjs98xCGkHjQtGKh86rDcRZN17QEMCOIs= +github.com/xdg-go/stringprep v1.0.3/go.mod h1:W3f5j4i+9rC0kuIEJL0ky1VpHXQU3ocBgklLGvcBnW8= github.com/youmark/pkcs8 v0.0.0-20181117223130-1be2e3e5546d h1:splanxYIlg+5LfHAM6xpdFEAYOk8iySO56hMFq6uLyA= github.com/youmark/pkcs8 v0.0.0-20181117223130-1be2e3e5546d/go.mod h1:rHwXgn7JulP+udvsHwJoVG1YGAP6VLg4y9I5dyZdqmA= -go.mongodb.org/mongo-driver v1.8.2 h1:8ssUXufb90ujcIvR6MyE1SchaNj0SFxsakiZgxIyrMk= -go.mongodb.org/mongo-driver v1.8.2/go.mod h1:0sQWfOeY63QTntERDJJ/0SuKK0T1uVSgKCuAROlKEPY= -golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= -golang.org/x/crypto v0.0.0-20201216223049-8b5274cf687f h1:aZp0e2vLN4MToVqnjNEYEtrEA8RH8U8FN1CU7JgqsPU= -golang.org/x/crypto v0.0.0-20201216223049-8b5274cf687f/go.mod h1:jdWPYTVW3xRLrWPugEBEK3UY2ZEsg3UU495nc5E+M+I= -golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= -golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= -golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e h1:vcxGaoTs7kV8m5Np9uUNQin4BrLOthgV7252N8V+FwY= -golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= -golang.org/x/sys v0.0.0-20191026070338-33540a1f6037/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/term v0.0.0-20201117132131-f5c789dd3221/go.mod h1:Nr5EML6q2oocZ2LXRh80K7BxOlk5/8JxuGnuhpl+muw= -golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= -golang.org/x/text v0.3.5 h1:i6eZZ+zk0SOf0xgBpEpPD18qWcJda6q1sxt3S0kzyUQ= -golang.org/x/text v0.3.5/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +go.mongodb.org/mongo-driver v1.11.6 h1:XM7G6PjiGAO5betLF13BIa5TlLUUE3uJ/2Ox3Lz1K+o= +go.mongodb.org/mongo-driver v1.11.6/go.mod h1:G9TgswdsWjX4tmDA5zfs2+6AEPpYJwqblyjsfuh8oXY= +golang.org/x/crypto v0.0.0-20220622213112-05595931fe9d h1:sK3txAijHtOK88l68nt020reeT1ZdKLIYetKl95FzVY= +golang.org/x/crypto v0.0.0-20220622213112-05595931fe9d/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= +golang.org/x/net v0.0.0-20211112202133-69e39bad7dc2/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= +golang.org/x/sync v0.0.0-20210220032951-036812b2e83c h1:5KslGYwFpkhGh+Q16bwMP3cOontH8FOep7tGV86Y7SQ= +golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= +golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/text v0.3.7 h1:olpwvP2KacW1ZWvsR7uQhoyTYvKAupfQrRGBFM352Gk= +golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= -golang.org/x/tools v0.0.0-20190531172133-b3315ee88b7d/go.mod h1:/rFqwRUd4F7ZHNgwSSTFct+R/Kf4OFW1sUzUTQQTgfc= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 h1:qIbj1fsPNlZgppZ+VLlY7N33q108Sa+fhmuc+sWQYwY= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= -gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/mongox-testing/database/ephemeral.go b/mongox-testing/database/ephemeral.go index 299d82a..1786126 100644 --- a/mongox-testing/database/ephemeral.go +++ b/mongox-testing/database/ephemeral.go @@ -29,18 +29,24 @@ func init() { // NewEphemeral creates new mongo connection func NewEphemeral(URI string) (db *EphemeralDatabase, err error) { + return NewEphemeralWithContext(context.Background(), URI) +} +func NewEphemeralWithContext(ctx context.Context, URI string) (db *EphemeralDatabase, err error) { if URI == "" { URI = defaultURI } name := primitive.NewObjectID().Hex() opts := options.Client().ApplyURI(URI) - client, err := mongo.Connect(context.Background(), opts) + client, err := mongo.Connect(ctx, opts) + if err != nil { + return nil, err + } - db = &EphemeralDatabase{Database: database.NewDatabase(client, name)} + db = &EphemeralDatabase{Database: database.NewDatabase(ctx, client, name)} - return + return db, nil } // Close the connection and drop database diff --git a/mongox/base/docbased/id.go b/mongox/base/docbased/id.go new file mode 100644 index 0000000..1d41efa --- /dev/null +++ b/mongox/base/docbased/id.go @@ -0,0 +1,44 @@ +package docbased + +import ( + "github.com/modern-go/reflect2" + "go.mongodb.org/mongo-driver/bson/primitive" + + "github.com/mainnika/mongox-go-driver/v2/mongox" +) + +var _ mongox.DocBased = (*Primary)(nil) + +// Primary is a structure with object as an _id field +type Primary struct { + ID primitive.D `bson:"_id" json:"_id"` +} + +// GetID returns an _id +func (p *Primary) GetID() (id primitive.D) { + return p.ID +} + +// SetID sets an _id +func (p *Primary) SetID(id primitive.D) { + p.ID = id +} + +// New creates a new Primary structure with a defined _id +func New(e primitive.E, ee ...primitive.E) Primary { + id := primitive.D{e} + if len(ee) > 0 { + id = append(id, ee...) + } + + return Primary{ID: id} +} + +func GetID(source mongox.DocBased) (id primitive.D, err error) { + id = source.GetID() + if !reflect2.IsNil(id) { + return id, nil + } + + return nil, mongox.ErrUninitializedBase +} diff --git a/mongox/base/jsonbased/id_test.go b/mongox/base/docbased/id_test.go similarity index 61% rename from mongox/base/jsonbased/id_test.go rename to mongox/base/docbased/id_test.go index 41e2ce6..617c5a2 100644 --- a/mongox/base/jsonbased/id_test.go +++ b/mongox/base/docbased/id_test.go @@ -1,4 +1,4 @@ -package jsonbased_test +package docbased_test import ( "encoding/json" @@ -8,27 +8,25 @@ import ( "go.mongodb.org/mongo-driver/bson/primitive" "github.com/mainnika/mongox-go-driver/v2/mongox-testing/database" - "github.com/mainnika/mongox-go-driver/v2/mongox/base/jsonbased" + "github.com/mainnika/mongox-go-driver/v2/mongox/base/docbased" ) func Test_GetID(t *testing.T) { - type DocWithObject struct { - jsonbased.Primary `bson:",inline" json:",inline" collection:"1"` + docbased.Primary `bson:",inline" json:",inline" collection:"1"` } - doc := &DocWithObject{Primary: jsonbased.Primary{ID: primitive.D{{"1", "one"}, {"2", "two"}}}} + doc := &DocWithObject{Primary: docbased.New(primitive.E{"1", "one"}, primitive.E{"2", "two"})} assert.Equal(t, primitive.D{{"1", "one"}, {"2", "two"}}, doc.GetID()) } func Test_SetID(t *testing.T) { - type DocWithObject struct { - jsonbased.Primary `bson:",inline" json:",inline" collection:"1"` + docbased.Primary `bson:",inline" json:",inline" collection:"1"` } - doc := &DocWithObject{Primary: jsonbased.Primary{ID: primitive.D{{"1", "one"}, {"2", "two"}}}} + doc := &DocWithObject{Primary: docbased.New(primitive.E{"1", "one"}, primitive.E{"2", "two"})} doc.SetID(primitive.D{{"3", "three"}, {"4", "you"}}) @@ -37,9 +35,8 @@ func Test_SetID(t *testing.T) { } func Test_SaveLoad(t *testing.T) { - type DocWithObjectID struct { - jsonbased.Primary `bson:",inline" json:",inline" collection:"1"` + docbased.Primary `bson:",inline" json:",inline" collection:"1"` } db, err := database.NewEphemeral("") @@ -47,9 +44,9 @@ func Test_SaveLoad(t *testing.T) { t.Fatal(err) } - defer db.Close() + defer func() { _ = db.Close() }() - doc1 := &DocWithObjectID{Primary: jsonbased.Primary{ID: primitive.D{{"1", "one"}, {"2", "two"}}}} + doc1 := &DocWithObjectID{Primary: docbased.New(primitive.E{"1", "one"}, primitive.E{"2", "two"})} doc2 := &DocWithObjectID{} err = db.SaveOne(doc1) @@ -67,12 +64,11 @@ func Test_SaveLoad(t *testing.T) { } func Test_Marshal(t *testing.T) { - type DocWithObjectID struct { - jsonbased.Primary `bson:",inline" json:",inline" collection:"1"` + docbased.Primary `bson:",inline" json:",inline" collection:"1"` } - doc := &DocWithObjectID{Primary: jsonbased.Primary{ID: primitive.D{{"1", "one"}, {"2", "two"}}}} + doc := &DocWithObjectID{Primary: docbased.New(primitive.E{"1", "one"}, primitive.E{"2", "two"})} bytes, err := json.Marshal(doc) assert.NoError(t, err) diff --git a/mongox/base/getid.go b/mongox/base/getid.go index b916bbd..0daf86f 100644 --- a/mongox/base/getid.go +++ b/mongox/base/getid.go @@ -2,70 +2,26 @@ package base import ( "fmt" - - "github.com/modern-go/reflect2" - "go.mongodb.org/mongo-driver/bson/primitive" + "github.com/mainnika/mongox-go-driver/v2/mongox/base/docbased" + "github.com/mainnika/mongox-go-driver/v2/mongox/base/ifacebased" + "github.com/mainnika/mongox-go-driver/v2/mongox/base/oidbased" + "github.com/mainnika/mongox-go-driver/v2/mongox/base/stringbased" "github.com/mainnika/mongox-go-driver/v2/mongox" ) // GetID returns source document id -func GetID(source interface{}) (id interface{}) { - +func GetID(source interface{}) (id interface{}, err error) { switch doc := source.(type) { case mongox.OIDBased: - return getObjectIDOrGenerate(doc) + return oidbased.GetID(doc) case mongox.StringBased: - return getStringIDOrPanic(doc) - case mongox.JSONBased: - return getObjectOrPanic(doc) + return stringbased.GetID(doc) + case mongox.DocBased: + return docbased.GetID(doc) case mongox.InterfaceBased: - return getInterfaceOrPanic(doc) - + return ifacebased.GetID(doc) default: - panic(fmt.Errorf("source contains malformed document, %v", source)) - } -} - -func getObjectIDOrGenerate(source mongox.OIDBased) (id primitive.ObjectID) { - - id = source.GetID() - if id != primitive.NilObjectID { - return id - } - - id = primitive.NewObjectID() - source.SetID(id) - - return -} - -func getStringIDOrPanic(source mongox.StringBased) (id string) { - - id = source.GetID() - if id != "" { - return id - } - - panic(fmt.Errorf("source contains malformed document, %v", source)) -} - -func getObjectOrPanic(source mongox.JSONBased) (id primitive.D) { - - id = source.GetID() - if id != nil { - return id + return nil, fmt.Errorf("%w: unknown base type", mongox.ErrMalformedBase) } - - panic(fmt.Errorf("source contains malformed document, %v", source)) -} - -func getInterfaceOrPanic(source mongox.InterfaceBased) (id interface{}) { - - id = source.GetID() - if !reflect2.IsNil(id) { - return id - } - - panic(fmt.Errorf("source contains malformed document, %v", source)) } diff --git a/mongox/base/getid_test.go b/mongox/base/getid_test.go index d8d8c92..f4d5765 100644 --- a/mongox/base/getid_test.go +++ b/mongox/base/getid_test.go @@ -1,13 +1,14 @@ package base_test import ( + "github.com/mainnika/mongox-go-driver/v2/mongox/base/docbased" + "github.com/stretchr/testify/require" "testing" "github.com/stretchr/testify/assert" "go.mongodb.org/mongo-driver/bson/primitive" "github.com/mainnika/mongox-go-driver/v2/mongox/base" - "github.com/mainnika/mongox-go-driver/v2/mongox/base/jsonbased" "github.com/mainnika/mongox-go-driver/v2/mongox/base/oidbased" "github.com/mainnika/mongox-go-driver/v2/mongox/base/stringbased" ) @@ -29,15 +30,24 @@ func TestGetID(t *testing.T) { type DocWithObjectID struct { oidbased.Primary `bson:",inline" json:",inline" collection:"1"` } + id, err := base.GetID(&DocWithObjectID{Primary: oidbased.Primary{ID: [12]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2}}}) + require.NoError(t, err) + assert.Equal(t, primitive.ObjectID([12]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2}), id) + type DocWithObject struct { - jsonbased.Primary `bson:",inline" json:",inline" collection:"2"` + docbased.Primary `bson:",inline" json:",inline" collection:"2"` } + id, err = base.GetID(&DocWithObject{Primary: docbased.Primary{ID: primitive.D{{"1", "2"}}}}) + require.NoError(t, err) + assert.Equal(t, primitive.D{{"1", "2"}}, id) + type DocWithString struct { stringbased.Primary `bson:",inline" json:",inline" collection:"3"` } + id, err = base.GetID(&DocWithString{Primary: stringbased.Primary{ID: "foobar"}}) + assert.Equal(t, "foobar", id) - assert.Equal(t, primitive.ObjectID([12]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2}), base.GetID(&DocWithObjectID{Primary: oidbased.Primary{ID: [12]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2}}})) - assert.Equal(t, primitive.D{{"1", "2"}}, base.GetID(&DocWithObject{Primary: jsonbased.Primary{ID: primitive.D{{"1", "2"}}}})) - assert.Equal(t, "foobar", base.GetID(&DocWithString{Primary: stringbased.Primary{ID: "foobar"}})) - assert.Equal(t, 420, base.GetID(&DocWithCustomInterface{ID: 420})) + id, err = base.GetID(&DocWithCustomInterface{ID: 420}) + assert.NoError(t, err) + assert.Equal(t, 420, id) } diff --git a/mongox/base/getprotection.go b/mongox/base/getprotection.go deleted file mode 100644 index 3076ef7..0000000 --- a/mongox/base/getprotection.go +++ /dev/null @@ -1,40 +0,0 @@ -package base - -import ( - "reflect" - - "github.com/mainnika/mongox-go-driver/v2/mongox/base/protection" -) - -// GetProtection function finds protection field in the source document otherwise returns nil -func GetProtection(source interface{}) (key *protection.Key) { - - v := reflect.ValueOf(source) - if v.Kind() != reflect.Ptr || v.IsNil() { - return - } - - el := v.Elem() - numField := el.NumField() - - for i := 0; i < numField; i++ { - field := el.Field(i) - if !field.CanInterface() { - continue - } - - switch field.Interface().(type) { - case *protection.Key: - key = field.Interface().(*protection.Key) - case protection.Key: - ptr := field.Addr() - key = ptr.Interface().(*protection.Key) - default: - continue - } - - return - } - - return -} diff --git a/mongox/base/ifacebased/id.go b/mongox/base/ifacebased/id.go new file mode 100644 index 0000000..3f962eb --- /dev/null +++ b/mongox/base/ifacebased/id.go @@ -0,0 +1,16 @@ +package ifacebased + +import ( + "github.com/mainnika/mongox-go-driver/v2/mongox" + "github.com/modern-go/reflect2" +) + +// GetID returns an _id from the source document +func GetID(source mongox.InterfaceBased) (id interface{}, err error) { + id = source.GetID() + if !reflect2.IsNil(id) { + return id, nil + } + + return nil, mongox.ErrUninitializedBase +} diff --git a/mongox/base/jsonbased/id.go b/mongox/base/jsonbased/id.go deleted file mode 100644 index 730fed5..0000000 --- a/mongox/base/jsonbased/id.go +++ /dev/null @@ -1,24 +0,0 @@ -package jsonbased - -import ( - "go.mongodb.org/mongo-driver/bson/primitive" - - "github.com/mainnika/mongox-go-driver/v2/mongox" -) - -var _ mongox.JSONBased = (*Primary)(nil) - -// Primary is a structure with object as an _id field -type Primary struct { - ID primitive.D `bson:"_id" json:"_id"` -} - -// GetID returns an _id -func (p *Primary) GetID() (id primitive.D) { - return p.ID -} - -// SetID sets an _id -func (p *Primary) SetID(id primitive.D) { - p.ID = id -} diff --git a/mongox/base/oidbased/id.go b/mongox/base/oidbased/id.go index 3982047..ecc0df5 100644 --- a/mongox/base/oidbased/id.go +++ b/mongox/base/oidbased/id.go @@ -22,3 +22,22 @@ func (p *Primary) GetID() (id primitive.ObjectID) { func (p *Primary) SetID(id primitive.ObjectID) { p.ID = id } + +// Generate creates a new Primary structure with a new objectId +func Generate() Primary { + return Primary{ID: primitive.NewObjectID()} +} + +// New creates a new Primary structure with a defined objectId +func New(id primitive.ObjectID) Primary { + return Primary{ID: id} +} + +func GetID(source mongox.OIDBased) (id primitive.ObjectID, err error) { + id = source.GetID() + if id != primitive.NilObjectID { + return id, nil + } + + return primitive.NilObjectID, mongox.ErrUninitializedBase +} diff --git a/mongox/base/oidbased/id_test.go b/mongox/base/oidbased/id_test.go index 8633ac3..f6b1de8 100644 --- a/mongox/base/oidbased/id_test.go +++ b/mongox/base/oidbased/id_test.go @@ -12,24 +12,22 @@ import ( ) func Test_GetID(t *testing.T) { - type DocWithObjectID struct { oidbased.Primary `bson:",inline" json:",inline" collection:"1"` } - doc := &DocWithObjectID{Primary: oidbased.Primary{ID: [12]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2}}} + doc := &DocWithObjectID{Primary: oidbased.New([12]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2})} + assert.Equal(t, primitive.ObjectID([12]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2}), doc.Primary.ID) assert.Equal(t, primitive.ObjectID([12]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2}), doc.GetID()) } func Test_SetID(t *testing.T) { - type DocWithObjectID struct { oidbased.Primary `bson:",inline" json:",inline" collection:"1"` } - doc := &DocWithObjectID{} - + doc := &DocWithObjectID{Primary: oidbased.Generate()} doc.SetID([12]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2}) assert.Equal(t, primitive.ObjectID([12]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2}), doc.Primary.ID) @@ -37,7 +35,6 @@ func Test_SetID(t *testing.T) { } func Test_SaveLoad(t *testing.T) { - type DocWithObjectID struct { oidbased.Primary `bson:",inline" json:",inline" collection:"1"` } @@ -47,10 +44,10 @@ func Test_SaveLoad(t *testing.T) { t.Fatal(err) } - defer db.Close() + defer func() { _ = db.Close() }() - doc1 := &DocWithObjectID{} - doc2 := &DocWithObjectID{} + doc1 := &DocWithObjectID{Primary: oidbased.Generate()} + doc2 := &DocWithObjectID{Primary: oidbased.Generate()} err = db.SaveOne(doc1) assert.NoError(t, err) @@ -67,13 +64,12 @@ func Test_SaveLoad(t *testing.T) { } func Test_Marshal(t *testing.T) { - type DocWithObjectID struct { oidbased.Primary `bson:",inline" json:",inline" collection:"1"` } id, _ := primitive.ObjectIDFromHex("feadbeeffeadbeeffeadbeef") - doc := &DocWithObjectID{Primary: oidbased.Primary{ID: id}} + doc := &DocWithObjectID{Primary: oidbased.New(id)} bytes, err := json.Marshal(doc) assert.NoError(t, err) diff --git a/mongox/base/protection/key.go b/mongox/base/protection/key.go index b64f186..9c7b350 100644 --- a/mongox/base/protection/key.go +++ b/mongox/base/protection/key.go @@ -1,6 +1,7 @@ package protection import ( + "reflect" "time" "github.com/modern-go/reflect2" @@ -13,9 +14,8 @@ type Key struct { V int64 `bson:"_v" json:"_v"` } -// PutToDocument extends the doc with protection key values -func (k *Key) PutToDocument(doc primitive.M) { - +// Inject extends the doc with protection key values +func (k *Key) Inject(doc primitive.M) { if reflect2.IsNil(doc) { return } @@ -34,3 +34,35 @@ func (k *Key) Restate() { k.X = primitive.NewObjectID() k.V = time.Now().Unix() } + +// Get finds protection field in the source document otherwise returns nil +func Get(source interface{}) (key *Key) { + v := reflect.ValueOf(source) + if v.Kind() != reflect.Ptr || v.IsNil() { + return nil + } + + el := v.Elem() + numField := el.NumField() + + for i := 0; i < numField; i++ { + field := el.Field(i) + if !field.CanInterface() { + continue + } + + switch field.Interface().(type) { + case *Key: + key = field.Interface().(*Key) + case Key: + ptr := field.Addr() + key = ptr.Interface().(*Key) + default: + continue + } + + return key + } + + return nil +} diff --git a/mongox/base/protection/key_test.go b/mongox/base/protection/key_test.go new file mode 100644 index 0000000..4b64458 --- /dev/null +++ b/mongox/base/protection/key_test.go @@ -0,0 +1,3 @@ +package protection_test + +// TODO: diff --git a/mongox/base/reset.go b/mongox/base/reset.go index 5cd2d52..8e03c33 100644 --- a/mongox/base/reset.go +++ b/mongox/base/reset.go @@ -7,7 +7,6 @@ import ( // Reset function creates new zero object for the target pointer func Reset(target interface{}) (created bool) { - type resetter interface { Reset() } diff --git a/mongox/base/stringbased/id.go b/mongox/base/stringbased/id.go index 58a6fa3..906833b 100644 --- a/mongox/base/stringbased/id.go +++ b/mongox/base/stringbased/id.go @@ -20,3 +20,17 @@ func (p *Primary) GetID() (id string) { func (p *Primary) SetID(id string) { p.ID = id } + +// New creates a new Primary structure with a defined _id +func New(id string) Primary { + return Primary{ID: id} +} + +func GetID(source mongox.StringBased) (id string, err error) { + id = source.GetID() + if id != "" { + return id, nil + } + + return "", mongox.ErrUninitializedBase +} diff --git a/mongox/base/stringbased/id_test.go b/mongox/base/stringbased/id_test.go index 5565867..0a2a61a 100644 --- a/mongox/base/stringbased/id_test.go +++ b/mongox/base/stringbased/id_test.go @@ -11,23 +11,21 @@ import ( ) func Test_GetID(t *testing.T) { - type DocWithString struct { stringbased.Primary `bson:",inline" json:",inline" collection:"1"` } - doc := &DocWithString{Primary: stringbased.Primary{ID: "foobar"}} + doc := &DocWithString{Primary: stringbased.New("foobar")} assert.Equal(t, "foobar", doc.GetID()) } func Test_SetID(t *testing.T) { - type DocWithString struct { stringbased.Primary `bson:",inline" json:",inline" collection:"1"` } - doc := &DocWithString{Primary: stringbased.Primary{ID: "foobar"}} + doc := &DocWithString{Primary: stringbased.New("foobar")} doc.SetID("rockrockrock") @@ -36,7 +34,6 @@ func Test_SetID(t *testing.T) { } func Test_SaveLoad(t *testing.T) { - type DocWithObjectID struct { stringbased.Primary `bson:",inline" json:",inline" collection:"1"` } @@ -46,9 +43,9 @@ func Test_SaveLoad(t *testing.T) { t.Fatal(err) } - defer db.Close() + defer func() { _ = db.Close() }() - doc1 := &DocWithObjectID{Primary: stringbased.Primary{ID: "foobar"}} + doc1 := &DocWithObjectID{Primary: stringbased.New("foobar")} doc2 := &DocWithObjectID{} err = db.SaveOne(doc1) @@ -66,12 +63,11 @@ func Test_SaveLoad(t *testing.T) { } func Test_Marshal(t *testing.T) { - type DocWithObjectID struct { stringbased.Primary `bson:",inline" json:",inline" collection:"1"` } - doc := &DocWithObjectID{Primary: stringbased.Primary{ID: "foobar"}} + doc := &DocWithObjectID{Primary: stringbased.New("foobar")} bytes, err := json.Marshal(doc) assert.NoError(t, err) diff --git a/mongox/database/context.go b/mongox/database/context.go new file mode 100644 index 0000000..6f5dc67 --- /dev/null +++ b/mongox/database/context.go @@ -0,0 +1,23 @@ +package database + +import ( + "context" +) + +type ctxDatabaseKey struct{} + +// GetFromContext function extracts the request data from context +func GetFromContext(ctx context.Context) (q *Database, ok bool) { + q, ok = ctx.Value(ctxDatabaseKey{}).(*Database) + if !ok { + return nil, false + } + + return q, true +} + +// WithContext creates the new context with a database attached +func WithContext(ctx context.Context, q *Database) (withQuery context.Context) { + db := NewDatabase(ctx, q.Client(), q.Name()) + return context.WithValue(ctx, ctxDatabaseKey{}, db) +} diff --git a/mongox/database/count.go b/mongox/database/count.go index d8711af..3ee5298 100644 --- a/mongox/database/count.go +++ b/mongox/database/count.go @@ -9,22 +9,29 @@ import ( // Count function counts documents in the database by query // target is used only to get collection by tag so it'd be better to use nil ptr here func (d *Database) Count(target interface{}, filters ...interface{}) (result int64, err error) { - composed, err := query.Compose(filters...) if err != nil { - return + return -1, err } - collection := d.GetCollectionOf(target) + collection, err := d.GetCollectionOf(target) + if err != nil { + return -1, err + } ctx := query.WithContext(d.Context(), composed) + m := composed.M() + opts := options.Count() opts.Limit = composed.Limiter() opts.Skip = composed.Skipper() - result, err = collection.CountDocuments(ctx, composed.M(), opts) + defer func() { _ = composed.OnClose().Invoke(ctx, target) }() - _ = composed.OnClose().Invoke(ctx, target) + result, err = collection.CountDocuments(ctx, m, opts) + if err != nil { + return -1, err + } - return + return result, nil } diff --git a/mongox/database/cursor.go b/mongox/database/cursor.go new file mode 100644 index 0000000..7ef7d95 --- /dev/null +++ b/mongox/database/cursor.go @@ -0,0 +1,188 @@ +package database + +import ( + "fmt" + "reflect" + "strconv" + "strings" + + "go.mongodb.org/mongo-driver/bson/primitive" + "go.mongodb.org/mongo-driver/mongo/options" + + "github.com/mainnika/mongox-go-driver/v2/mongox" + "github.com/mainnika/mongox-go-driver/v2/mongox/query" +) + +func (d *Database) createCursor(target interface{}, composed *query.Query) (cursor *mongox.Cursor, err error) { + _, hasPreloader := composed.Preloader() + if hasPreloader { + return d.createAggregateCursor(target, composed) + } + + return d.createSimpleCursor(target, composed) +} + +func (d *Database) createSimpleCursor(target interface{}, composed *query.Query) (cursor *mongox.Cursor, err error) { + collection, err := d.GetCollectionOf(target) + if err != nil { + return nil, err + } + + opts := options.Find() + opts.Sort = composed.Sorter() + opts.Limit = composed.Limiter() + opts.Skip = composed.Skipper() + + ctx := d.Context() + m := composed.M() + + return collection.Find(ctx, m, opts) +} + +func (d *Database) createAggregateCursor(target interface{}, composed *query.Query) (cursor *mongox.Cursor, err error) { + collection, err := d.GetCollectionOf(target) + if err != nil { + return nil, err + } + + pipeline := primitive.A{} + if !composed.Empty() { + pipeline = append(pipeline, primitive.M{"$match": composed.M()}) + } + if composed.Sorter() != nil { + pipeline = append(pipeline, primitive.M{"$sort": composed.Sorter()}) + } + if composed.Skipper() != nil { + pipeline = append(pipeline, primitive.M{"$skip": *composed.Skipper()}) + } + if composed.Limiter() != nil { + pipeline = append(pipeline, primitive.M{"$limit": *composed.Limiter()}) + } + + el := reflect.ValueOf(target) + elType := el.Type() + if elType.Kind() == reflect.Ptr { + elType = elType.Elem() + } + + numField := elType.NumField() + preloads, _ := composed.Preloader() + for i := 0; i < numField; i++ { + field := elType.Field(i) + tag := field.Tag + + preloadTag, ok := tag.Lookup("preload") + if !ok { + continue + } + jsonTag, _ := tag.Lookup("json") + if jsonTag == "-" { + return nil, fmt.Errorf("%w: private field is not preloadable", mongox.ErrMalformedBase) + } + + jsonData := strings.SplitN(jsonTag, ",", 2) + jsonName := field.Name + if len(jsonData) > 0 { + jsonName = strings.TrimSpace(jsonData[0]) + } + + preloadData := strings.Split(preloadTag, ",") + if len(preloadData) == 0 { + continue + } + if len(preloadData) == 1 { + return nil, fmt.Errorf("%w: foreign field is not specified", mongox.ErrMalformedBase) + } + + foreignField := strings.TrimSpace(preloadData[1]) + if len(foreignField) == 0 { + return nil, fmt.Errorf("%w: foreign field is empty", mongox.ErrMalformedBase) + } + localField := strings.TrimSpace(preloadData[0]) + if len(localField) == 0 { + localField = "_id" + } + + preloadLimiter := 100 + preloadReversed := false + if len(preloadData) > 2 { + stringLimit := strings.TrimSpace(preloadData[2]) + intLimit := preloadLimiter + + preloadReversed = strings.HasPrefix(stringLimit, "-") + if preloadReversed { + stringLimit = stringLimit[1:] + } + + intLimit, err = strconv.Atoi(stringLimit) + if err == nil { + preloadLimiter = intLimit + } else { + return nil, fmt.Errorf("%w: preload limit should be an integer", mongox.ErrMalformedBase) + } + } + + for _, preload := range preloads { + if preload != jsonName { + continue + } + + field := elType.Field(i) + fieldType := field.Type + + isSlice := fieldType.Kind() == reflect.Slice + if isSlice { + fieldType = fieldType.Elem() + } + + isPtr := fieldType.Kind() != reflect.Ptr + if isPtr { + return nil, fmt.Errorf("%w: preload field should have ptr type", mongox.ErrMalformedBase) + } + + lookupCollection, err := d.GetCollectionOf(reflect.Zero(fieldType).Interface()) + if err != nil { + return nil, err + } + + lookupVars := primitive.M{"selector": "$" + localField} + lookupPipeline := primitive.A{ + primitive.M{"$match": primitive.M{"$expr": primitive.M{"$eq": primitive.A{"$" + foreignField, "$$selector"}}}}, + } + + if preloadReversed { + lookupPipeline = append(lookupPipeline, primitive.M{"$sort": primitive.M{"_id": -1}}) + } + if isSlice && preloadLimiter > 0 { + lookupPipeline = append(lookupPipeline, primitive.M{"$limit": preloadLimiter}) + } else if !isSlice { + lookupPipeline = append(lookupPipeline, primitive.M{"$limit": 1}) + } + + pipeline = append(pipeline, primitive.M{ + "$lookup": primitive.M{ + "from": lookupCollection.Name(), + "let": lookupVars, + "pipeline": lookupPipeline, + "as": jsonName, + }, + }) + + if isSlice { + continue + } + + pipeline = append(pipeline, primitive.M{ + "$unwind": primitive.M{ + "preserveNullAndEmptyArrays": true, + "path": "$" + jsonName, + }, + }) + } + } + + ctx := d.Context() + opts := options.Aggregate() + + return collection.Aggregate(ctx, pipeline, opts) +} diff --git a/mongox/database/database.go b/mongox/database/database.go index 3d2816a..185e569 100644 --- a/mongox/database/database.go +++ b/mongox/database/database.go @@ -2,34 +2,27 @@ package database import ( "context" - "fmt" "reflect" - "strconv" - "strings" - - "go.mongodb.org/mongo-driver/bson/primitive" - "go.mongodb.org/mongo-driver/mongo/options" "github.com/mainnika/mongox-go-driver/v2/mongox" - "github.com/mainnika/mongox-go-driver/v2/mongox/query" ) // Database handler type Database struct { client *mongox.Client - dbname string + name string ctx context.Context } // NewDatabase function creates new database instance with mongo client and empty context -func NewDatabase(client *mongox.Client, dbname string) (db mongox.Database) { - +func NewDatabase(ctx context.Context, client *mongox.Client, name string) (db mongox.Database) { db = &Database{ client: client, - dbname: dbname, + name: name, + ctx: ctx, } - return + return db } // Client function returns a mongo client @@ -37,213 +30,42 @@ func (d *Database) Client() (client *mongox.Client) { return d.client } -// Context function returns a context -func (d *Database) Context() (ctx context.Context) { - - ctx = d.ctx - if ctx == nil { - ctx = context.Background() - } - - return -} - // Name function returns a database name func (d *Database) Name() (name string) { - return d.dbname + return d.name } -// New function creates new database context with same client -func (d *Database) New(ctx context.Context) (db mongox.Database) { - +// Context function returns a context +func (d *Database) Context() (ctx context.Context) { + ctx = d.ctx if ctx == nil { ctx = context.Background() } - db = &Database{ - client: d.client, - dbname: d.dbname, - ctx: ctx, - } - - return + return ctx } // GetCollectionOf returns the collection object by the «collection» tag of the given document; -// the «collection» tag should exists, e.g.: -// type Foobar struct { -// base.ObjectID `bson:",inline" json:",inline" collection:"foobars"` -// ... -// Will panic if there is no «collection» tag -func (d *Database) GetCollectionOf(document interface{}) (collection *mongox.Collection) { - +// +// example: +// type Foobar struct { +// base.ObjectID `bson:",inline" json:",inline" collection:"foobars"` +// ... +func (d *Database) GetCollectionOf(document interface{}) (collection *mongox.Collection, err error) { el := reflect.TypeOf(document).Elem() numField := el.NumField() + databaseName := d.name for i := 0; i < numField; i++ { field := el.Field(i) tag := field.Tag - found, ok := tag.Lookup("collection") - if !ok { + collectionName, found := tag.Lookup("collection") + if !found { continue } - return d.client.Database(d.dbname).Collection(found) - } - - panic(fmt.Errorf("document %v does not have a collection tag", document)) -} - -func (d *Database) createSimpleLoad(target interface{}, composed *query.Query) (cursor *mongox.Cursor, err error) { - - collection := d.GetCollectionOf(target) - opts := options.Find() - - opts.Sort = composed.Sorter() - opts.Limit = composed.Limiter() - opts.Skip = composed.Skipper() - - return collection.Find(d.Context(), composed.M(), opts) -} - -func (d *Database) createAggregateLoad(target interface{}, composed *query.Query) (cursor *mongox.Cursor, err error) { - - collection := d.GetCollectionOf(target) - opts := options.Aggregate() - - pipeline := primitive.A{} - - if !composed.Empty() { - pipeline = append(pipeline, primitive.M{"$match": composed.M()}) - } - if composed.Sorter() != nil { - pipeline = append(pipeline, primitive.M{"$sort": composed.Sorter()}) - } - if composed.Skipper() != nil { - pipeline = append(pipeline, primitive.M{"$skip": *composed.Skipper()}) - } - if composed.Limiter() != nil { - pipeline = append(pipeline, primitive.M{"$limit": *composed.Limiter()}) - } - - el := reflect.ValueOf(target) - elType := el.Type() - if elType.Kind() == reflect.Ptr { - elType = elType.Elem() - } - - numField := elType.NumField() - preloads, _ := composed.Preloader() - - for i := 0; i < numField; i++ { - - field := elType.Field(i) - tag := field.Tag - - preloadTag, ok := tag.Lookup("preload") - if !ok { - continue - } - jsonTag, _ := tag.Lookup("json") - if jsonTag == "-" { - panic(fmt.Errorf("preload private field is impossible")) - } - - jsonData := strings.SplitN(jsonTag, ",", 2) - jsonName := field.Name - if len(jsonData) > 0 { - jsonName = strings.TrimSpace(jsonData[0]) - } - - preloadData := strings.Split(preloadTag, ",") - if len(preloadData) == 0 { - continue - } - if len(preloadData) == 1 { - panic(fmt.Errorf("there is no foreign field")) - } - - localField := strings.TrimSpace(preloadData[0]) - if len(localField) == 0 { - localField = "_id" - } - - foreignField := strings.TrimSpace(preloadData[1]) - if len(foreignField) == 0 { - panic(fmt.Errorf("there is no foreign field")) - } - - preloadLimiter := 100 - preloadReversed := false - if len(preloadData) > 2 { - stringLimit := strings.TrimSpace(preloadData[2]) - intLimit := preloadLimiter - - preloadReversed = strings.HasPrefix(stringLimit, "-") - if preloadReversed { - stringLimit = stringLimit[1:] - } - - intLimit, err = strconv.Atoi(stringLimit) - if err == nil { - preloadLimiter = intLimit - } - } - - for _, preload := range preloads { - if preload != jsonName { - continue - } - - field := elType.Field(i) - fieldType := field.Type - - isSlice := fieldType.Kind() == reflect.Slice - if isSlice { - fieldType = fieldType.Elem() - } - - isPtr := fieldType.Kind() != reflect.Ptr - if isPtr { - panic(fmt.Errorf("preload field should have ptr type")) - } - - lookupCollection := d.GetCollectionOf(reflect.Zero(fieldType).Interface()) - lookupVars := primitive.M{"selector": "$" + localField} - lookupPipeline := primitive.A{ - primitive.M{"$match": primitive.M{"$expr": primitive.M{"$eq": primitive.A{"$" + foreignField, "$$selector"}}}}, - } - - if preloadReversed { - lookupPipeline = append(lookupPipeline, primitive.M{"$sort": primitive.M{"_id": -1}}) - } - if isSlice && preloadLimiter > 0 { - lookupPipeline = append(lookupPipeline, primitive.M{"$limit": preloadLimiter}) - } else if !isSlice { - lookupPipeline = append(lookupPipeline, primitive.M{"$limit": 1}) - } - - pipeline = append(pipeline, primitive.M{ - "$lookup": primitive.M{ - "from": lookupCollection.Name(), - "let": lookupVars, - "pipeline": lookupPipeline, - "as": jsonName, - }, - }) - - if isSlice { - continue - } - - pipeline = append(pipeline, primitive.M{ - "$unwind": primitive.M{ - "preserveNullAndEmptyArrays": true, - "path": "$" + jsonName, - }, - }) - } + return d.client.Database(databaseName).Collection(collectionName), nil } - return collection.Aggregate(d.Context(), pipeline, opts) + return nil, mongox.ErrNoCollection } diff --git a/mongox/database/deletearray.go b/mongox/database/deletearray.go index 2eb1521..d8d9ebc 100644 --- a/mongox/database/deletearray.go +++ b/mongox/database/deletearray.go @@ -13,7 +13,6 @@ import ( // DeleteArray removes documents list from a database by their ids func (d *Database) DeleteArray(target interface{}, filters ...interface{}) (err error) { - targetV := reflect.ValueOf(target) targetT := targetV.Type() @@ -36,42 +35,43 @@ func (d *Database) DeleteArray(target interface{}, filters ...interface{}) (err zeroElem := reflect.Zero(targetSliceElemT) targetLen := targetSliceV.Len() - composed, err := query.Compose(filters...) + collection, err := d.GetCollectionOf(zeroElem.Interface()) if err != nil { - return + return err } - collection := d.GetCollectionOf(zeroElem.Interface()) - ids := primitive.A{} - ctx := query.WithContext(d.Context(), composed) - - for i := 0; i < targetLen; i++ { - elem := targetSliceV.Index(i) - ids = append(ids, base.GetID(elem.Interface())) + composed, err := query.Compose(filters...) + if err != nil { + return err } - defer func() { - invokerr := composed.OnClose().Invoke(ctx, target) - if err == nil { - err = invokerr - } - - return - }() + if targetLen > 0 { + var ids primitive.A + for i := 0; i < targetLen; i++ { + elem := targetSliceV.Index(i) + elemID, err := base.GetID(elem.Interface()) + if err != nil { + return err + } - if len(ids) == 0 { - return fmt.Errorf("can't delete zero elements") + ids = append(ids, elemID) + } + composed.And(primitive.M{"_id": primitive.M{"$in": ids}}) } - composed.And(primitive.M{"_id": primitive.M{"$in": ids}}) + ctx := query.WithContext(d.Context(), composed) + m := composed.M() + opts := options.Delete() + + defer func() { _ = composed.OnClose().Invoke(ctx, target) }() - result, err := collection.DeleteMany(ctx, composed.M(), options.Delete()) + result, err := collection.DeleteMany(ctx, m, opts) if err != nil { - return + return fmt.Errorf("while deleting array: %w", err) } if result.DeletedCount != int64(targetLen) { - err = fmt.Errorf("can't verify delete result: removed count mismatch %d != %d", result.DeletedCount, targetLen) + return fmt.Errorf("deleted count mismatch %d != %d", result.DeletedCount, targetLen) } - return + return nil } diff --git a/mongox/database/deleteone.go b/mongox/database/deleteone.go index 658cc90..ca5e7fe 100644 --- a/mongox/database/deleteone.go +++ b/mongox/database/deleteone.go @@ -2,6 +2,7 @@ package database import ( "fmt" + "github.com/mainnika/mongox-go-driver/v2/mongox/base/protection" "github.com/modern-go/reflect2" "go.mongodb.org/mongo-driver/bson/primitive" @@ -13,51 +14,53 @@ import ( // DeleteOne removes a document from a database and then returns it into target func (d *Database) DeleteOne(target interface{}, filters ...interface{}) (err error) { - composed, err := query.Compose(filters...) if err != nil { - return + return err } - collection := d.GetCollectionOf(target) - protected := base.GetProtection(target) - ctx := query.WithContext(d.Context(), composed) - - opts := options.FindOneAndDelete() - opts.Sort = composed.Sorter() + collection, err := d.GetCollectionOf(target) + if err != nil { + return err + } if !reflect2.IsNil(target) { - composed.And(primitive.M{"_id": base.GetID(target)}) + targetID, err := base.GetID(target) + if err != nil { + return err + } + + composed.And(primitive.M{"_id": targetID}) } + protected := protection.Get(target) if protected != nil { - query.Push(composed, protected) + _, err := query.Push(composed, protected) + if err != nil { + return err + } + protected.Restate() } - defer func() { - invokerr := composed.OnClose().Invoke(ctx, target) - if err == nil { - err = invokerr - } + ctx := query.WithContext(d.Context(), composed) + m := composed.M() + opts := options.FindOneAndDelete() + opts.Sort = composed.Sorter() - return - }() + defer func() { _ = composed.OnClose().Invoke(ctx, target) }() - result := collection.FindOneAndDelete(ctx, composed.M(), opts) + result := collection.FindOneAndDelete(ctx, m, opts) if result.Err() != nil { return fmt.Errorf("can't create find one and delete result: %w", result.Err()) } err = result.Decode(target) if err != nil { - return + return fmt.Errorf("can't decode find one and delete result: %w", err) } - err = composed.OnDecode().Invoke(ctx, target) - if err != nil { - return - } + _ = composed.OnDecode().Invoke(ctx, target) - return + return nil } diff --git a/mongox/database/index.go b/mongox/database/index.go index e8c14f3..a4e95b5 100644 --- a/mongox/database/index.go +++ b/mongox/database/index.go @@ -14,18 +14,21 @@ import ( ) // IndexEnsure function ensures index in mongo collection of document -// `index:""` -- https://docs.mongodb.com/manual/indexes/#create-an-index -// `index:"-"` -- (descending) -// `index:"-,+foo,+-bar"` -- https://docs.mongodb.com/manual/core/index-compound -// `index:"-,+foo,+-bar,unique"` -- https://docs.mongodb.com/manual/core/index-unique -// `index:"-,+foo,+-bar,unique,allowNull"` -- https://docs.mongodb.com/manual/core/index-partial -// `index:"-,unique,allowNull,expireAfter=86400"` -- https://docs.mongodb.com/manual/core/index-ttl -// `index:"-,unique,allowNull,expireAfter={{.Expire}}"` -- evaluate index as a golang template with `cfg` arguments +// +// `index:""` -- https://docs.mongodb.com/manual/indexes/#create-an-index +// `index:"-"` -- (descending) +// `index:"-,+foo,+-bar"` -- https://docs.mongodb.com/manual/core/index-compound +// `index:"-,+foo,+-bar,unique"` -- https://docs.mongodb.com/manual/core/index-unique +// `index:"-,+foo,+-bar,unique,allowNull"` -- https://docs.mongodb.com/manual/core/index-partial +// `index:"-,unique,allowNull,expireAfter=86400"` -- https://docs.mongodb.com/manual/core/index-ttl +// `index:"-,unique,allowNull,expireAfter={{.Expire}}"` -- evaluate index as a golang template with `cfg` arguments func (d *Database) IndexEnsure(cfg interface{}, document interface{}) (err error) { - el := reflect.ValueOf(document).Elem().Type() numField := el.NumField() - documents := d.GetCollectionOf(document) + collection, err := d.GetCollectionOf(document) + if err != nil { + return err + } for i := 0; i < numField; i++ { @@ -126,7 +129,7 @@ func (d *Database) IndexEnsure(cfg interface{}, document interface{}) (err error } } - _, err = documents.Indexes().CreateOne(d.Context(), mongo.IndexModel{Keys: index, Options: opts}) + _, err = collection.Indexes().CreateOne(d.Context(), mongo.IndexModel{Keys: index, Options: opts}) if err != nil { return } diff --git a/mongox/database/index_test.go b/mongox/database/index_test.go index 9c570c6..2355d79 100644 --- a/mongox/database/index_test.go +++ b/mongox/database/index_test.go @@ -1,6 +1,7 @@ package database_test import ( + "github.com/stretchr/testify/require" "testing" "github.com/stretchr/testify/assert" @@ -142,7 +143,10 @@ func TestDatabase_Ensure(t *testing.T) { err = db.IndexEnsure(tt.settings, tt.doc) assert.NoError(t, err) - indexes, _ := db.GetCollectionOf(tt.doc).Indexes().List(db.Context()) + collection, err := db.GetCollectionOf(tt.doc) + require.NoError(t, err) + + indexes, _ := collection.Indexes().List(db.Context()) index := new(map[string]interface{}) indexes.Next(db.Context()) // skip _id_ diff --git a/mongox/database/loadarray.go b/mongox/database/loadarray.go index d52e8b7..02d7981 100644 --- a/mongox/database/loadarray.go +++ b/mongox/database/loadarray.go @@ -4,14 +4,12 @@ import ( "fmt" "reflect" - "github.com/mainnika/mongox-go-driver/v2/mongox" "github.com/mainnika/mongox-go-driver/v2/mongox/base" "github.com/mainnika/mongox-go-driver/v2/mongox/query" ) // LoadArray loads an array of documents from the database by query func (d *Database) LoadArray(target interface{}, filters ...interface{}) (err error) { - targetV := reflect.ValueOf(target) targetT := targetV.Type() @@ -31,61 +29,36 @@ func (d *Database) LoadArray(target interface{}, filters ...interface{}) (err er panic(fmt.Errorf("target slice should contain ptrs")) } + zeroElem := reflect.Zero(targetSliceElemT) + composed, err := query.Compose(filters...) if err != nil { - return + return err } - zeroElem := reflect.Zero(targetSliceElemT) - _, hasPreloader := composed.Preloader() ctx := query.WithContext(d.Context(), composed) - var result *mongox.Cursor - var i int - - defer func() { - - if result != nil { - closerr := result.Close(ctx) - if err == nil { - err = closerr - } - } - - invokerr := composed.OnClose().Invoke(ctx, target) - if err == nil { - err = invokerr - } - - return - }() + defer func() { _ = composed.OnClose().Invoke(ctx, target) }() - if hasPreloader { - result, err = d.createAggregateLoad(zeroElem.Interface(), composed) - } else { - result, err = d.createSimpleLoad(zeroElem.Interface(), composed) - } + cur, err := d.createCursor(zeroElem.Interface(), composed) if err != nil { - err = fmt.Errorf("can't create find result: %w", err) - return + return fmt.Errorf("can't create find result: %w", err) } - for i = 0; result.Next(ctx); i++ { + defer func() { _ = cur.Close(ctx) }() + var i int + for i = 0; cur.Next(ctx); i++ { var elem interface{} - if i == targetSliceV.Len() { value := reflect.New(targetSliceElemT.Elem()) elem = value.Interface() - err = composed.OnCreate().Invoke(ctx, elem) - if err != nil { - return - } + _ = composed.OnCreate().Invoke(ctx, elem) - err = result.Decode(elem) + err = cur.Decode(elem) if err != nil { - return + return err } targetSliceV = reflect.Append(targetSliceV, value) @@ -93,26 +66,24 @@ func (d *Database) LoadArray(target interface{}, filters ...interface{}) (err er elem = targetSliceV.Index(i).Interface() if created := base.Reset(elem); created { - err = composed.OnCreate().Invoke(ctx, elem) - } - if err != nil { - return + _ = composed.OnCreate().Invoke(ctx, elem) } - err = result.Decode(elem) + err = cur.Decode(elem) if err != nil { - return + return err } } - err = composed.OnDecode().Invoke(ctx, elem) - if err != nil { - return - } + _ = composed.OnDecode().Invoke(ctx, elem) + } + err = cur.Err() + if err != nil { + return err } targetSliceV = targetSliceV.Slice(0, i) targetV.Elem().Set(targetSliceV) - return + return nil } diff --git a/mongox/database/loadone.go b/mongox/database/loadone.go index 216e696..4fc040b 100644 --- a/mongox/database/loadone.go +++ b/mongox/database/loadone.go @@ -1,8 +1,6 @@ package database import ( - "fmt" - "github.com/mainnika/mongox-go-driver/v2/mongox" "github.com/mainnika/mongox-go-driver/v2/mongox/base" "github.com/mainnika/mongox-go-driver/v2/mongox/query" @@ -10,68 +8,39 @@ import ( // LoadOne function loads a first single target document by a query func (d *Database) LoadOne(target interface{}, filters ...interface{}) (err error) { - composed, err := query.Compose(append(filters, query.Limit(1))...) if err != nil { - return + return err } - _, hasPreloader := composed.Preloader() ctx := query.WithContext(d.Context(), composed) - var result *mongox.Cursor - - defer func() { - - if result != nil { - closerr := result.Close(ctx) - if err == nil { - err = closerr - } - } - - invokerr := composed.OnClose().Invoke(ctx, target) - if err == nil { - err = invokerr - } - - return - }() + defer func() { _ = composed.OnClose().Invoke(ctx, target) }() - if hasPreloader { - result, err = d.createAggregateLoad(target, composed) - } else { - result, err = d.createSimpleLoad(target, composed) - } + cur, err := d.createCursor(target, composed) if err != nil { - return fmt.Errorf("can't create find result: %w", err) + return err } + defer func() { _ = cur.Close(ctx) }() - hasNext := result.Next(ctx) - if result.Err() != nil { - err = result.Err() - return + hasNext := cur.Next(ctx) + if cur.Err() != nil { + return cur.Err() } if !hasNext { return mongox.ErrNoDocuments } if created := base.Reset(target); created { - err = composed.OnCreate().Invoke(ctx, target) - } - if err != nil { - return + _ = composed.OnCreate().Invoke(ctx, target) } - err = result.Decode(target) + err = cur.Decode(target) if err != nil { - return + return err } - err = composed.OnDecode().Invoke(ctx, target) - if err != nil { - return - } + _ = composed.OnDecode().Invoke(ctx, target) - return + return nil } diff --git a/mongox/database/loadstream.go b/mongox/database/loadstream.go index cbeadb9..dcbcdcb 100644 --- a/mongox/database/loadstream.go +++ b/mongox/database/loadstream.go @@ -1,36 +1,25 @@ package database import ( - "fmt" - "github.com/mainnika/mongox-go-driver/v2/mongox" "github.com/mainnika/mongox-go-driver/v2/mongox/query" ) // LoadStream function loads documents one by one into a target channel func (d *Database) LoadStream(target interface{}, filters ...interface{}) (loader mongox.StreamLoader, err error) { - composed, err := query.Compose(filters...) if err != nil { return } - _, hasPreloader := composed.Preloader() ctx := query.WithContext(d.Context(), composed) - var cursor *mongox.Cursor - - if hasPreloader { - cursor, err = d.createAggregateLoad(target, composed) - } else { - cursor, err = d.createSimpleLoad(target, composed) - } + cur, err := d.createCursor(target, composed) if err != nil { - err = fmt.Errorf("can't create find result: %w", err) - return + return nil, err } - loader = &StreamLoader{cur: cursor, ctx: ctx, query: composed} + loader = &StreamLoader{cur: cur, ctx: ctx, query: composed} - return + return loader, nil } diff --git a/mongox/database/saveone.go b/mongox/database/saveone.go index 3fb6642..0f3c813 100644 --- a/mongox/database/saveone.go +++ b/mongox/database/saveone.go @@ -1,6 +1,7 @@ package database import ( + "github.com/mainnika/mongox-go-driver/v2/mongox/base/protection" "go.mongodb.org/mongo-driver/bson/primitive" "go.mongodb.org/mongo-driver/mongo/options" @@ -10,38 +11,37 @@ import ( // SaveOne saves a single source document to the database func (d *Database) SaveOne(source interface{}, filters ...interface{}) (err error) { + collection, err := d.GetCollectionOf(source) + if err != nil { + return err + } composed, err := query.Compose(filters...) if err != nil { - return + return err } - collection := d.GetCollectionOf(source) - id := base.GetID(source) - protected := base.GetProtection(source) - ctx := query.WithContext(d.Context(), composed) - + id, err := base.GetID(source) + if err != nil { + return err + } composed.And(primitive.M{"_id": id}) - opts := options.FindOneAndReplace() - opts.SetUpsert(true) - opts.SetReturnDocument(options.After) - + protected := protection.Get(source) if protected != nil { query.Push(composed, protected) protected.Restate() } - defer func() { - invokerr := composed.OnClose().Invoke(ctx, source) - if err == nil { - err = invokerr - } + ctx := query.WithContext(d.Context(), composed) + m := composed.M() + opts := options.FindOneAndReplace() + opts.SetUpsert(true) + opts.SetReturnDocument(options.After) - return - }() + defer func() { _ = composed.OnClose().Invoke(ctx, source) }() - result := collection.FindOneAndReplace(ctx, composed.M(), source, opts) + result := collection.FindOneAndReplace(ctx, m, source, opts) if result.Err() != nil { return result.Err() } diff --git a/mongox/database/streamloader.go b/mongox/database/streamloader.go index dbe8989..7b501cd 100644 --- a/mongox/database/streamloader.go +++ b/mongox/database/streamloader.go @@ -17,7 +17,6 @@ type StreamLoader struct { // DecodeNextMsg decodes the next document to an interface or returns an error func (l *StreamLoader) DecodeNextMsg(i interface{}) (err error) { - err = l.Next() if err != nil { return @@ -33,41 +32,35 @@ func (l *StreamLoader) DecodeNextMsg(i interface{}) (err error) { // DecodeMsg decodes the current cursor document into an interface func (l *StreamLoader) DecodeMsg(i interface{}) (err error) { - if created := base.Reset(i); created { - err = l.query.OnDecode().Invoke(l.ctx, i) + _ = l.query.OnDecode().Invoke(l.ctx, i) } if err != nil { - return + return err } err = l.cur.Decode(i) if err != nil { - return + return err } - err = l.query.OnDecode().Invoke(l.ctx, i) - if err != nil { - return - } + _ = l.query.OnDecode().Invoke(l.ctx, i) - return + return nil } // Next loads next documents but doesn't perform decoding func (l *StreamLoader) Next() (err error) { - hasNext := l.cur.Next(l.ctx) err = l.cur.Err() - if err != nil { - return + return err } if !hasNext { - err = mongox.ErrNoDocuments + return mongox.ErrNoDocuments } - return + return nil } // Cursor returns the underlying cursor @@ -77,24 +70,21 @@ func (l *StreamLoader) Cursor() (cursor *mongox.Cursor) { // Close stream loader and the underlying cursor func (l *StreamLoader) Close() (err error) { + defer func() { _ = l.query.OnClose().Invoke(l.ctx, nil) }() - closerr := l.cur.Close(l.ctx) - invokerr := l.query.OnClose().Invoke(l.ctx, nil) - - if closerr != nil { - err = closerr - return - } - - if invokerr != nil { - err = invokerr - return + err = l.cur.Close(l.ctx) + if err != nil { + return err } - return + return nil } // Err returns the last error func (l *StreamLoader) Err() (err error) { return l.cur.Err() } + +func (l *StreamLoader) Context() (ctx context.Context) { + return l.ctx +} diff --git a/mongox/database/updateone.go b/mongox/database/updateone.go index ee404b8..9f236bd 100644 --- a/mongox/database/updateone.go +++ b/mongox/database/updateone.go @@ -1,72 +1,64 @@ package database import ( + "github.com/mainnika/mongox-go-driver/v2/mongox/base/protection" "github.com/modern-go/reflect2" "go.mongodb.org/mongo-driver/bson/primitive" "go.mongodb.org/mongo-driver/mongo/options" - "github.com/mainnika/mongox-go-driver/v2/mongox/base" "github.com/mainnika/mongox-go-driver/v2/mongox/query" ) // UpdateOne updates a single document in the database and loads it into target func (d *Database) UpdateOne(target interface{}, filters ...interface{}) (err error) { - composed, err := query.Compose(filters...) if err != nil { - return + return err } - updaterDoc, err := composed.Updater() + update, err := composed.Updater() if err != nil { - return + return err } - collection := d.GetCollectionOf(target) - protected := base.GetProtection(target) - ctx := query.WithContext(d.Context(), composed) - - opts := options.FindOneAndUpdate() - opts.SetReturnDocument(options.After) - + protected := protection.Get(target) if protected != nil { if !protected.X.IsZero() { query.Push(composed, protected) } - protected.Restate() - setCmd, _ := updaterDoc["$set"].(primitive.M) + setCmd, _ := update["$set"].(primitive.M) if reflect2.IsNil(setCmd) { setCmd = primitive.M{} } - protected.PutToDocument(setCmd) - updaterDoc["$set"] = setCmd + protected.Inject(setCmd) + update["$set"] = setCmd } - defer func() { - invokerr := composed.OnClose().Invoke(ctx, target) - if err == nil { - err = invokerr - } + collection, err := d.GetCollectionOf(target) + if err != nil { + return err + } + + ctx := query.WithContext(d.Context(), composed) + m := composed.M() + opts := options.FindOneAndUpdate() + opts.SetReturnDocument(options.After) - return - }() + defer func() { _ = composed.OnClose().Invoke(ctx, target) }() - result := collection.FindOneAndUpdate(ctx, composed.M(), updaterDoc, opts) + result := collection.FindOneAndUpdate(ctx, m, update, opts) if result.Err() != nil { return result.Err() } err = result.Decode(target) if err != nil { - return + return err } - err = composed.OnDecode().Invoke(ctx, target) - if err != nil { - return - } + _ = composed.OnDecode().Invoke(ctx, target) - return + return nil } diff --git a/mongox/errors.go b/mongox/errors.go index 4f043a1..1873bcb 100644 --- a/mongox/errors.go +++ b/mongox/errors.go @@ -1,6 +1,8 @@ package mongox import ( + "errors" + "go.mongodb.org/mongo-driver/mongo" ) @@ -18,3 +20,9 @@ var ( ErrWrongClient = mongo.ErrWrongClient ErrNoDocuments = mongo.ErrNoDocuments ) + +var ( + ErrMalformedBase = errors.New("source contains malformed document base") + ErrUninitializedBase = errors.New("uninitialized document") + ErrNoCollection = errors.New("no collection found") +) diff --git a/mongox/mongox.go b/mongox/mongox.go index 97a09a1..c9c7c85 100644 --- a/mongox/mongox.go +++ b/mongox/mongox.go @@ -19,8 +19,7 @@ type Database interface { Client() (client *Client) Context() (context context.Context) Name() (name string) - New(ctx context.Context) (db Database) - GetCollectionOf(document interface{}) (collection *Collection) + GetCollectionOf(document interface{}) (collection *Collection, err error) Count(target interface{}, filters ...interface{}) (count int64, err error) DeleteArray(target interface{}, filters ...interface{}) (err error) DeleteOne(target interface{}, filters ...interface{}) (err error) @@ -54,8 +53,8 @@ type StringBased interface { SetID(id string) } -// JSONBased is an interface for documents that have object type for the _id field -type JSONBased interface { +// DocBased is an interface for documents that have object type for the _id field +type DocBased interface { GetID() (id primitive.D) SetID(id primitive.D) } diff --git a/mongox/query/callbacks.go b/mongox/query/callbacks.go index 3dc1437..2a1f5c9 100644 --- a/mongox/query/callbacks.go +++ b/mongox/query/callbacks.go @@ -18,9 +18,9 @@ func (c Callbacks) Invoke(ctx context.Context, iter interface{}) (err error) { for _, cb := range c { err = cb(ctx, iter) if err != nil { - return + return err } } - return + return nil } diff --git a/mongox/query/callbacks_test.go b/mongox/query/callbacks_test.go new file mode 100644 index 0000000..2d76902 --- /dev/null +++ b/mongox/query/callbacks_test.go @@ -0,0 +1,58 @@ +package query_test + +import ( + "context" + "fmt" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + + "github.com/mainnika/mongox-go-driver/v2/mongox/query" +) + +func TestCallbacks_InvokeOk(t *testing.T) { + + mocked := mock.Mock{} + + var callbacks = query.Callbacks{ + func(ctx context.Context, iter interface{}) (err error) { + return mocked.Called(ctx, iter).Error(0) + }, + func(ctx context.Context, iter interface{}) (err error) { + return mocked.Called(ctx, iter).Error(0) + }, + } + + ctx := context.Background() + iter := int64(42) + + mocked.On("func1", ctx, iter).Return(nil).Once() + mocked.On("func2", ctx, iter).Return(nil).Once() + + assert.NoError(t, callbacks.Invoke(ctx, iter)) + assert.True(t, mocked.AssertExpectations(t)) +} + +func TestCallbacks_InvokeStopIfError(t *testing.T) { + + mocked := mock.Mock{} + + var callbacks = query.Callbacks{ + func(ctx context.Context, iter interface{}) (err error) { + return mocked.Called(ctx, iter).Error(0) + }, + func(ctx context.Context, iter interface{}) (err error) { + t.FailNow() + return + }, + } + + ctx := context.Background() + iter := int(42) + + mocked.On("func1", ctx, iter).Return(fmt.Errorf("wat")) + + assert.EqualError(t, callbacks.Invoke(ctx, iter), "wat") + assert.True(t, mocked.AssertExpectations(t)) +} diff --git a/mongox/query/compose.go b/mongox/query/compose.go index 98f9990..edfa9cd 100644 --- a/mongox/query/compose.go +++ b/mongox/query/compose.go @@ -9,13 +9,11 @@ import ( "github.com/mainnika/mongox-go-driver/v2/mongox/base/protection" ) -type applyFilterFunc = func(query *Query, filter interface{}) (ok bool) +type applyFilterFunc func(query *Query, filter interface{}) (ok bool) // Compose is a function to compose filters into a single query func Compose(filters ...interface{}) (query *Query, err error) { - query = &Query{} - for _, filter := range filters { ok, err := Push(query, filter) if err != nil { @@ -26,25 +24,25 @@ func Compose(filters ...interface{}) (query *Query, err error) { } } - return + return query, nil } // Push applies single filter to a query func Push(query *Query, filter interface{}) (ok bool, err error) { - - ok = reflect2.IsNil(filter) - if ok { - return + emptyFilter := reflect2.IsNil(filter) + if emptyFilter { + return true, nil } validator, hasValidator := filter.(Validator) if hasValidator { - err = validator.Validate() - } - if err != nil { - return + err := validator.Validate() + if err != nil { + return false, fmt.Errorf("while validating filter %v, %w", filter, err) + } } + ok = false // true if at least one filter was applied for _, applier := range []applyFilterFunc{ applyBson, applyLimit, @@ -58,12 +56,11 @@ func Push(query *Query, filter interface{}) (ok bool, err error) { ok = applier(query, filter) || ok } - return + return ok, nil } // applyBson is a fallback for a custom primitive.M func applyBson(query *Query, filter interface{}) (ok bool) { - if filter, ok := filter.(primitive.M); ok { query.And(filter) return true @@ -74,7 +71,6 @@ func applyBson(query *Query, filter interface{}) (ok bool) { // applyLimits extends query with a limiter func applyLimit(query *Query, filter interface{}) (ok bool) { - if filter, ok := filter.(Limiter); ok { query.limiter = filter return true @@ -85,7 +81,6 @@ func applyLimit(query *Query, filter interface{}) (ok bool) { // applySort extends query with a sort rule func applySort(query *Query, filter interface{}) (ok bool) { - if filter, ok := filter.(Sorter); ok { query.sorter = filter return true @@ -96,7 +91,6 @@ func applySort(query *Query, filter interface{}) (ok bool) { // applySkip extends query with a skip number func applySkip(query *Query, filter interface{}) (ok bool) { - if filter, ok := filter.(Skipper); ok { query.skipper = filter return true @@ -106,15 +100,12 @@ func applySkip(query *Query, filter interface{}) (ok bool) { } func applyProtection(query *Query, filter interface{}) (ok bool) { - - var keyDoc = primitive.M{} - + keyDoc := primitive.M{} switch filter := filter.(type) { case protection.Key: - filter.PutToDocument(keyDoc) + filter.Inject(keyDoc) case *protection.Key: - filter.PutToDocument(keyDoc) - + filter.Inject(keyDoc) default: return false } @@ -125,7 +116,6 @@ func applyProtection(query *Query, filter interface{}) (ok bool) { } func applyPreloader(query *Query, filter interface{}) (ok bool) { - if filter, ok := filter.(Preloader); ok { query.preloader = filter return true @@ -135,7 +125,6 @@ func applyPreloader(query *Query, filter interface{}) (ok bool) { } func applyUpdater(query *Query, filter interface{}) (ok bool) { - if filter, ok := filter.(Updater); ok { query.updater = filter return true @@ -145,12 +134,11 @@ func applyUpdater(query *Query, filter interface{}) (ok bool) { } func applyCallbacks(query *Query, filter interface{}) (ok bool) { - switch callback := filter.(type) { case OnDecode: - query.ondecode = append(query.ondecode, Callback(callback)) + query.onDecode = append(query.onDecode, Callback(callback)) case OnClose: - query.onclose = append(query.onclose, Callback(callback)) + query.onClose = append(query.onClose, Callback(callback)) default: return false } diff --git a/mongox/query/compose_test.go b/mongox/query/compose_test.go new file mode 100644 index 0000000..3692357 --- /dev/null +++ b/mongox/query/compose_test.go @@ -0,0 +1,128 @@ +package query_test + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "go.mongodb.org/mongo-driver/bson/primitive" + + "github.com/mainnika/mongox-go-driver/v2/mongox/base/protection" + "github.com/mainnika/mongox-go-driver/v2/mongox/query" +) + +func TestPushBSON(t *testing.T) { + + q := &query.Query{} + + ok, err := query.Push(q, primitive.M{"foo": "bar"}) + + assert.True(t, ok) + assert.NoError(t, err) + assert.NotEmpty(t, q.M()) + assert.Len(t, q.M()["$and"], 1) + assert.Contains(t, q.M()["$and"], primitive.M{"foo": "bar"}) + + ok, err = query.Push(q, primitive.M{"bar": "foo"}) + + assert.True(t, ok) + assert.NoError(t, err) + assert.NotEmpty(t, q.M()) + assert.Len(t, q.M()["$and"], 2) + assert.Contains(t, q.M()["$and"], primitive.M{"foo": "bar"}) + assert.Contains(t, q.M()["$and"], primitive.M{"bar": "foo"}) +} + +func TestPushLimiter(t *testing.T) { + + q := &query.Query{} + lim := query.Limit(2) + + ok, err := query.Push(q, lim) + + assert.True(t, ok) + assert.NoError(t, err) + assert.NotNil(t, q.Limiter()) + assert.EqualValues(t, q.Limiter(), query.Limit(2).Limit()) +} + +func TestPushSorter(t *testing.T) { + + q := &query.Query{} + sort := query.Sort{"foo": 1} + + ok, err := query.Push(q, sort) + + assert.True(t, ok) + assert.NoError(t, err) + assert.NotNil(t, q.Sorter()) + assert.EqualValues(t, q.Sorter(), primitive.M{"foo": 1}) +} + +func TestPushSkipper(t *testing.T) { + + q := &query.Query{} + skip := query.Skip(66) + + ok, err := query.Push(q, skip) + + assert.True(t, ok) + assert.NoError(t, err) + assert.NotNil(t, q.Skipper()) + assert.EqualValues(t, q.Skipper(), query.Skip(66).Skip()) +} + +func TestPushProtection(t *testing.T) { + + t.Run("push protection key pointer", func(t *testing.T) { + q := &query.Query{} + protected := &protection.Key{V: 1, X: primitive.ObjectID{2}} + + ok, err := query.Push(q, protected) + + assert.True(t, ok) + assert.NoError(t, err) + assert.NotEmpty(t, q.M()["$and"]) + assert.Contains(t, q.M()["$and"], primitive.M{"_x": primitive.ObjectID{2}, "_v": int64(1)}) + }) + + t.Run("push protection key struct", func(t *testing.T) { + q := &query.Query{} + protected := protection.Key{V: 1, X: primitive.ObjectID{2}} + + ok, err := query.Push(q, protected) + + assert.True(t, ok) + assert.NoError(t, err) + assert.NotEmpty(t, q.M()["$and"]) + assert.Contains(t, q.M()["$and"], primitive.M{"_x": primitive.ObjectID{2}, "_v": int64(1)}) + }) + + t.Run("protection key is empty", func(t *testing.T) { + q := &query.Query{} + protected := &protection.Key{} + + ok, err := query.Push(q, protected) + + assert.True(t, ok) + assert.NoError(t, err) + assert.NotEmpty(t, q.M()["$and"]) + assert.Contains(t, q.M()["$and"], primitive.M{"_x": primitive.M{"$exists": false}, "_v": primitive.M{"$exists": false}}) + }) +} + +func TestPushPreloader(t *testing.T) { + + q := &query.Query{} + preloader := query.Preload{"a", "b"} + + ok, err := query.Push(q, preloader) + + assert.True(t, ok) + assert.NoError(t, err) + + p, hasPreloader := q.Preloader() + + assert.NotNil(t, p) + assert.True(t, hasPreloader) + assert.EqualValues(t, p, query.Preload{"a", "b"}) +} diff --git a/mongox/query/context.go b/mongox/query/context.go index 07e5b8a..0ffca29 100644 --- a/mongox/query/context.go +++ b/mongox/query/context.go @@ -9,11 +9,14 @@ type ctxQueryKey struct{} // GetFromContext function extracts the request data from context func GetFromContext(ctx context.Context) (q *Query, ok bool) { q, ok = ctx.Value(ctxQueryKey{}).(*Query) - return + if !ok { + return nil, false + } + + return q, true } // WithContext function creates the new context with request data func WithContext(ctx context.Context, q *Query) (withQuery context.Context) { - withQuery = context.WithValue(ctx, ctxQueryKey{}, q) - return + return context.WithValue(ctx, ctxQueryKey{}, q) } diff --git a/mongox/query/limit.go b/mongox/query/limit.go index 349ee2d..2511d20 100644 --- a/mongox/query/limit.go +++ b/mongox/query/limit.go @@ -12,13 +12,12 @@ var _ Limiter = Limit(0) // Limit returns a limit func (l Limit) Limit() (limit *int64) { - if l <= 0 { - return + return nil } limit = new(int64) *limit = int64(l) - return + return limit } diff --git a/mongox/query/preload.go b/mongox/query/preload.go index 026aab9..c985075 100644 --- a/mongox/query/preload.go +++ b/mongox/query/preload.go @@ -1,11 +1,11 @@ package query -// Preloader is a filter to skip the result +// Preloader is a filter to preload the result type Preloader interface { Preload() (preloads []string) } -// Preload is a simple implementation of the Skipper filter +// Preload is a simple implementation of the Preloader filter type Preload []string var _ Preloader = Preload{} diff --git a/mongox/query/query.go b/mongox/query/query.go index be605da..89e6455 100644 --- a/mongox/query/query.go +++ b/mongox/query/query.go @@ -15,35 +15,31 @@ type Query struct { skipper Skipper preloader Preloader updater Updater - ondecode Callbacks - onclose Callbacks - oncreate Callbacks + onDecode Callbacks + onClose Callbacks + onCreate Callbacks } // And function pushes the elem query to the $and array of the query func (q *Query) And(elem primitive.M) (query *Query) { - if q.m == nil { q.m = primitive.M{} } queries, exists := q.m["$and"].(primitive.A) - if !exists { q.m["$and"] = primitive.A{elem} return q } q.m["$and"] = append(queries, elem) - return q } // Limiter returns limiter value or nil func (q *Query) Limiter() (limit *int64) { - if q.limiter == nil { - return + return nil } return q.limiter.Limit() @@ -51,9 +47,8 @@ func (q *Query) Limiter() (limit *int64) { // Sorter is a sort rule for a query func (q *Query) Sorter() (sort interface{}) { - if q.sorter == nil { - return + return nil } return q.sorter.Sort() @@ -61,9 +56,8 @@ func (q *Query) Sorter() (sort interface{}) { // Skipper is a skipper for a query func (q *Query) Skipper() (skip *int64) { - if q.skipper == nil { - return + return nil } return q.skipper.Skip() @@ -71,62 +65,57 @@ func (q *Query) Skipper() (skip *int64) { // Updater is an update command for a query func (q *Query) Updater() (update primitive.M, err error) { - if q.updater == nil { - update = primitive.M{} - return + return primitive.M{}, nil } update = q.updater.Update() - if reflect2.IsNil(update) { - update = primitive.M{} - return + return primitive.M{}, nil } buffer := bytebufferpool.Get() defer bytebufferpool.Put(buffer) // convert update document to bson map values + buffer.Reset() bsonBytes, err := bson.MarshalAppend(buffer.B, update) if err != nil { - return + return primitive.M{}, err } - update = primitive.M{} + update = primitive.M{} // reset update map and unmarshal bson bytes to it again err = bson.Unmarshal(bsonBytes, update) if err != nil { - return + return primitive.M{}, err } - return + return update, nil } // Preloader is a preloader list for a query func (q *Query) Preloader() (preloads []string, ok bool) { - if q.preloader == nil { return nil, false } preloads = q.preloader.Preload() - ok = len(preloads) > 0 - return + return preloads, len(preloads) > 0 } // OnDecode callback is called after the mongo decode function func (q *Query) OnDecode() (callbacks Callbacks) { - return q.ondecode + return q.onDecode } // OnClose callback is called after the mongox ends a loading procedure func (q *Query) OnClose() (callbacks Callbacks) { - return q.onclose + return q.onClose } // OnCreate callback is called if the mongox creates a new document instance during loading func (q *Query) OnCreate() (callbacks Callbacks) { - return q.onclose + return q.onClose } // Empty checks the query for any content @@ -138,3 +127,8 @@ func (q *Query) Empty() (isEmpty bool) { func (q *Query) M() (m primitive.M) { return q.m } + +// New creates a new query +func New() (query *Query) { + return &Query{} +} diff --git a/mongox/query/skip.go b/mongox/query/skip.go index 9b889e2..ca78419 100644 --- a/mongox/query/skip.go +++ b/mongox/query/skip.go @@ -12,13 +12,12 @@ var _ Skipper = Skip(0) // Skip returns a skip number func (l Skip) Skip() (skip *int64) { - if l <= 0 { - return + return nil } skip = new(int64) *skip = int64(l) - return + return skip }