diff --git a/mongox/common/common.go b/mongox/common/common.go new file mode 100644 index 0000000..a142629 --- /dev/null +++ b/mongox/common/common.go @@ -0,0 +1,140 @@ +package common + +import ( + "reflect" + "strconv" + "strings" + + "github.com/mainnika/mongox-go-driver/mongox" + "github.com/mainnika/mongox-go-driver/mongox/errors" + "github.com/mainnika/mongox-go-driver/mongox/query" + "go.mongodb.org/mongo-driver/bson/primitive" + "go.mongodb.org/mongo-driver/mongo" + "go.mongodb.org/mongo-driver/mongo/options" +) + +func createSimpleLoad(db *mongox.Database, target interface{}, composed *query.Query) (cursor *mongo.Cursor, err error) { + + collection := db.GetCollectionOf(target) + opts := options.Find() + + opts.Sort = composed.Sorter() + opts.Limit = composed.Limiter() + opts.Skip = composed.Skipper() + + return collection.Find(db.Context(), composed.M(), opts) +} + +func createAggregateLoad(db *mongox.Database, target interface{}, composed *query.Query) (cursor *mongo.Cursor, err error) { + + collection := db.GetCollectionOf(target) + opts := options.Aggregate() + + pipelineHead := primitive.A{primitive.M{"$match": composed.M()}} + pipelineTail := primitive.A{} + + el := reflect.ValueOf(target).Elem() + elType := el.Type() + numField := elType.NumField() + _, preloads := composed.Preloader() + + for i := 0; i < numField; i++ { + + field := elType.Field(i) + tag := field.Tag + + preloadTag, ok := tag.Lookup("preload") + if !ok { + continue + } + jsonTag, ok := tag.Lookup("json") + if jsonTag == "-" { + return nil, errors.Malformedf("preload private field is impossible") + } + + jsonData := strings.SplitN(jsonTag, ",", 2) + jsonName := field.Name + if len(jsonData) > 0 { + jsonName = strings.TrimSpace(jsonData[0]) + } + + preloadData := strings.Split(preloadTag, ",") + if len(preloadData) == 0 { + continue + } + if len(preloadData) == 1 { + panic("there is no foreign field") + } + + preloadName := strings.TrimSpace(preloadData[0]) + if len(preloadName) == 0 { + preloadName = jsonName + } + + foreignField := strings.TrimSpace(preloadData[1]) + if len(foreignField) == 0 { + panic("there is no foreign field") + } + + preloadLimiter := 100 + if len(preloadData) > 2 { + + stringLimit := strings.TrimSpace(preloadData[2]) + intLimit := preloadLimiter + + intLimit, err = strconv.Atoi(stringLimit) + if err == nil { + preloadLimiter = intLimit + } + } + + for _, preload := range preloads { + if preload != preloadName { + continue + } + + isPtr := el.Field(i).Kind() == reflect.Ptr + isSlice := el.Field(i).Kind() == reflect.Slice + isIface := el.Field(i).CanInterface() + if (!isPtr && !isSlice) || !isIface { + continue + } + + typ := el.Field(i).Type() + lookupCollection := db.GetCollectionOf(reflect.Zero(typ).Interface()) + lookupVars := primitive.M{"selector": "$_id"} + lookupPipeline := primitive.A{ + // todo: make match from composed query + primitive.M{"$match": primitive.M{"$expr": primitive.M{"$eq": primitive.A{"$" + foreignField, "$$selector"}}}}, + } + + if isSlice && preloadLimiter > 0 { + lookupPipeline = append(lookupPipeline, primitive.M{"$limit": preloadLimiter}) + } else if !isSlice { + lookupPipeline = append(lookupPipeline, primitive.M{"$limit": 1}) + } + + pipelineTail = append(pipelineTail, primitive.M{ + "$lookup": primitive.M{ + "from": lookupCollection.Name(), + "let": lookupVars, + "pipeline": lookupPipeline, + "as": jsonName, + }, + }) + + if isSlice { + continue + } + + pipelineTail = append(pipelineTail, primitive.M{ + "$unwind": primitive.M{ + "preserveNullAndEmptyArrays": true, + "path": "$" + jsonName, + }, + }) + } + } + + return collection.Aggregate(db.Context(), append(pipelineHead, pipelineTail...), opts) +} diff --git a/mongox/common/loadarray.go b/mongox/common/loadarray.go index ffcfd3b..aefb3fe 100644 --- a/mongox/common/loadarray.go +++ b/mongox/common/loadarray.go @@ -6,7 +6,7 @@ import ( "github.com/mainnika/mongox-go-driver/mongox" "github.com/mainnika/mongox-go-driver/mongox/errors" "github.com/mainnika/mongox-go-driver/mongox/query" - "go.mongodb.org/mongo-driver/mongo/options" + "go.mongodb.org/mongo-driver/mongo" ) // LoadArray loads an array of documents from the database by query @@ -31,16 +31,17 @@ func LoadArray(db *mongox.Database, target interface{}, filters ...interface{}) panic(errors.InternalErrorf("target slice should contain ptrs")) } - dummy := reflect.Zero(targetSliceElemT) - collection := db.GetCollectionOf(dummy.Interface()) - opts := options.Find() composed := query.Compose(filters...) + hasPreloader, _ := composed.Preloader() - opts.Sort = composed.Sorter() - opts.Limit = composed.Limiter() - opts.Skip = composed.Skipper() + var result *mongo.Cursor + var err error - result, err := collection.Find(db.Context(), composed.M(), opts) + if hasPreloader { + result, err = createAggregateLoad(db, target, composed) + } else { + result, err = createSimpleLoad(db, target, composed) + } if err != nil { return errors.InternalErrorf("can't create find result: %s", err) } diff --git a/mongox/common/loadmany.go b/mongox/common/loadmany.go index aa30c3f..d568b7a 100644 --- a/mongox/common/loadmany.go +++ b/mongox/common/loadmany.go @@ -7,7 +7,6 @@ import ( "github.com/mainnika/mongox-go-driver/mongox/errors" "github.com/mainnika/mongox-go-driver/mongox/query" "go.mongodb.org/mongo-driver/mongo" - "go.mongodb.org/mongo-driver/mongo/options" ) // ManyLoader is a controller for a database cursor @@ -43,15 +42,17 @@ func (l *ManyLoader) Close() error { // LoadMany function loads documents one by one into a target channel func LoadMany(db *mongox.Database, target interface{}, filters ...interface{}) (*ManyLoader, error) { - collection := db.GetCollectionOf(target) - opts := options.Find() - composed := query.Compose(filters...) + var cursor *mongo.Cursor + var err error - opts.Sort = composed.Sorter() - opts.Limit = composed.Limiter() - opts.Skip = composed.Skipper() + composed := query.Compose(filters...) + hasPreloader, _ := composed.Preloader() - cursor, err := collection.Find(db.Context(), composed.M(), opts) + if hasPreloader { + cursor, err = createAggregateLoad(db, target, composed) + } else { + cursor, err = createSimpleLoad(db, target, composed) + } if err != nil { return nil, errors.InternalErrorf("can't create find result: %s", err) } diff --git a/mongox/query/compose.go b/mongox/query/compose.go index 7c01602..9defeee 100644 --- a/mongox/query/compose.go +++ b/mongox/query/compose.go @@ -30,6 +30,7 @@ func Push(q *Query, f interface{}) bool { ok = ok || applySort(q, f) ok = ok || applySkip(q, f) ok = ok || applyProtection(q, f) + ok = ok || applyPreloader(q, f) return ok } @@ -108,3 +109,13 @@ func applyProtection(q *Query, f interface{}) bool { return true } + +func applyPreloader(q *Query, f interface{}) bool { + + if f, ok := f.(Preloader); ok { + q.preloader = f + return true + } + + return false +} diff --git a/mongox/query/preload.go b/mongox/query/preload.go new file mode 100644 index 0000000..6d6ad18 --- /dev/null +++ b/mongox/query/preload.go @@ -0,0 +1,17 @@ +package query + +// Preloader is a filter to skip the result +type Preloader interface { + Preload() []string +} + +// Preload is a simple implementation of the Skipper filter +type Preload []string + +var _ Preloader = Preload{} + +// Preload returns a preload list +func (l Preload) Preload() []string { + + return Preload(l) +} diff --git a/mongox/query/query.go b/mongox/query/query.go index 8438bdc..3b48dbd 100644 --- a/mongox/query/query.go +++ b/mongox/query/query.go @@ -8,10 +8,11 @@ import ( // Query is an enchanched bson.M map type Query struct { - m bson.M - limiter Limiter - sorter Sorter - skipper Skipper + m bson.M + limiter Limiter + sorter Sorter + skipper Skipper + preloader Preloader } // And function pushes the elem query to the $and array of the query @@ -63,6 +64,22 @@ func (q *Query) Skipper() *int64 { return q.skipper.Skip() } +// Preloader is a preloader list for a query +func (q *Query) Preloader() (empty bool, preloader []string) { + + if q.preloader == nil { + return false, nil + } + + preloader = q.preloader.Preload() + + if len(preloader) == 0 { + return false, nil + } + + return true, preloader +} + // Empty checks the query for any content func (q *Query) Empty() bool {