Merge common functions into database interface

v2
Nikita Tokarchuk 5 years ago
parent cf23c3b579
commit 7f46008227
  1. 162
      mongox/common/common.go
  2. 9
      mongox/database/count.go
  3. 152
      mongox/database/database.go
  4. 9
      mongox/database/deletearray.go
  5. 9
      mongox/database/deleteone.go
  6. 13
      mongox/database/loadarray.go
  7. 11
      mongox/database/loadone.go
  8. 33
      mongox/database/loadstream.go
  9. 9
      mongox/database/saveone.go
  10. 41
      mongox/database/streamloader.go
  11. 16
      mongox/mongox.go

@ -1,162 +0,0 @@
package common
import (
"fmt"
"reflect"
"strconv"
"strings"
"go.mongodb.org/mongo-driver/bson/primitive"
"go.mongodb.org/mongo-driver/mongo"
"go.mongodb.org/mongo-driver/mongo/options"
"github.com/mainnika/mongox-go-driver/v2/mongox"
"github.com/mainnika/mongox-go-driver/v2/mongox/query"
)
func createSimpleLoad(db mongox.Database, target interface{}, composed *query.Query) (cursor *mongo.Cursor, err error) {
collection := db.GetCollectionOf(target)
opts := options.Find()
opts.Sort = composed.Sorter()
opts.Limit = composed.Limiter()
opts.Skip = composed.Skipper()
return collection.Find(db.Context(), composed.M(), opts)
}
func createAggregateLoad(db mongox.Database, target interface{}, composed *query.Query) (cursor *mongo.Cursor, err error) {
collection := db.GetCollectionOf(target)
opts := options.Aggregate()
pipeline := primitive.A{}
if !composed.Empty() {
pipeline = append(pipeline, primitive.M{"$match": primitive.M{"$expr": composed.M()}})
}
if composed.Sorter() != nil {
pipeline = append(pipeline, primitive.M{"$sort": composed.Sorter()})
}
if composed.Skipper() != nil {
pipeline = append(pipeline, primitive.M{"$skip": *composed.Skipper()})
}
if composed.Limiter() != nil {
pipeline = append(pipeline, primitive.M{"$limit": *composed.Limiter()})
}
el := reflect.ValueOf(target).Elem()
elType := el.Type()
numField := elType.NumField()
_, preloads := composed.Preloader()
for i := 0; i < numField; i++ {
field := elType.Field(i)
tag := field.Tag
preloadTag, ok := tag.Lookup("preload")
if !ok {
continue
}
jsonTag, ok := tag.Lookup("json")
if jsonTag == "-" {
return nil, fmt.Errorf("preload private field is impossible")
}
jsonData := strings.SplitN(jsonTag, ",", 2)
jsonName := field.Name
if len(jsonData) > 0 {
jsonName = strings.TrimSpace(jsonData[0])
}
preloadData := strings.Split(preloadTag, ",")
if len(preloadData) == 0 {
continue
}
if len(preloadData) == 1 {
panic("there is no foreign field")
}
localField := strings.TrimSpace(preloadData[0])
if len(localField) == 0 {
localField = "_id"
}
foreignField := strings.TrimSpace(preloadData[1])
if len(foreignField) == 0 {
panic("there is no foreign field")
}
preloadLimiter := 100
preloadReversed := false
if len(preloadData) > 2 {
stringLimit := strings.TrimSpace(preloadData[2])
intLimit := preloadLimiter
preloadReversed = strings.HasPrefix(stringLimit, "-")
if preloadReversed {
stringLimit = stringLimit[1:]
}
intLimit, err = strconv.Atoi(stringLimit)
if err == nil {
preloadLimiter = intLimit
}
}
for _, preload := range preloads {
if preload != jsonName {
continue
}
isSlice := el.Field(i).Kind() == reflect.Slice
typ := el.Field(i).Type()
if typ.Kind() == reflect.Slice {
typ = typ.Elem()
}
if typ.Kind() != reflect.Ptr {
panic("preload field should have ptr type")
}
lookupCollection := db.GetCollectionOf(reflect.Zero(typ).Interface())
lookupVars := primitive.M{"selector": "$" + localField}
lookupPipeline := primitive.A{
primitive.M{"$match": primitive.M{"$expr": primitive.M{"$eq": primitive.A{"$" + foreignField, "$$selector"}}}},
}
if preloadReversed {
lookupPipeline = append(lookupPipeline, primitive.M{"$sort": primitive.M{"_id": -1}})
}
if isSlice && preloadLimiter > 0 {
lookupPipeline = append(lookupPipeline, primitive.M{"$limit": preloadLimiter})
} else if !isSlice {
lookupPipeline = append(lookupPipeline, primitive.M{"$limit": 1})
}
pipeline = append(pipeline, primitive.M{
"$lookup": primitive.M{
"from": lookupCollection.Name(),
"let": lookupVars,
"pipeline": lookupPipeline,
"as": jsonName,
},
})
if isSlice {
continue
}
pipeline = append(pipeline, primitive.M{
"$unwind": primitive.M{
"preserveNullAndEmptyArrays": true,
"path": "$" + jsonName,
},
})
}
}
return collection.Aggregate(db.Context(), pipeline, opts)
}

@ -1,4 +1,4 @@
package common package database
import ( import (
"fmt" "fmt"
@ -6,22 +6,21 @@ import (
"go.mongodb.org/mongo-driver/mongo" "go.mongodb.org/mongo-driver/mongo"
"go.mongodb.org/mongo-driver/mongo/options" "go.mongodb.org/mongo-driver/mongo/options"
"github.com/mainnika/mongox-go-driver/v2/mongox"
"github.com/mainnika/mongox-go-driver/v2/mongox/query" "github.com/mainnika/mongox-go-driver/v2/mongox/query"
) )
// Count function counts documents in the database by query // Count function counts documents in the database by query
// target is used only to get collection by tag so it'd be better to use nil ptr here // target is used only to get collection by tag so it'd be better to use nil ptr here
func Count(db mongox.Database, target interface{}, filters ...interface{}) (int64, error) { func (d *Database) Count(target interface{}, filters ...interface{}) (int64, error) {
collection := db.GetCollectionOf(target) collection := d.GetCollectionOf(target)
opts := options.Count() opts := options.Count()
composed := query.Compose(filters...) composed := query.Compose(filters...)
opts.Limit = composed.Limiter() opts.Limit = composed.Limiter()
opts.Skip = composed.Skipper() opts.Skip = composed.Skipper()
result, err := collection.CountDocuments(db.Context(), composed.M(), opts) result, err := collection.CountDocuments(d.Context(), composed.M(), opts)
if err == mongo.ErrNoDocuments { if err == mongo.ErrNoDocuments {
return 0, err return 0, err
} }

@ -4,10 +4,15 @@ import (
"context" "context"
"fmt" "fmt"
"reflect" "reflect"
"strconv"
"strings"
"go.mongodb.org/mongo-driver/bson/primitive"
"go.mongodb.org/mongo-driver/mongo" "go.mongodb.org/mongo-driver/mongo"
"go.mongodb.org/mongo-driver/mongo/options"
"github.com/mainnika/mongox-go-driver/v2/mongox" "github.com/mainnika/mongox-go-driver/v2/mongox"
"github.com/mainnika/mongox-go-driver/v2/mongox/query"
) )
// Database handler // Database handler
@ -80,3 +85,150 @@ func (d *Database) GetCollectionOf(document interface{}) mongox.MongoCollection
panic(fmt.Errorf("document %v does not have a collection tag", document)) panic(fmt.Errorf("document %v does not have a collection tag", document))
} }
func (d *Database) createSimpleLoad(target interface{}, composed *query.Query) (cursor *mongo.Cursor, err error) {
collection := d.GetCollectionOf(target)
opts := options.Find()
opts.Sort = composed.Sorter()
opts.Limit = composed.Limiter()
opts.Skip = composed.Skipper()
return collection.Find(d.Context(), composed.M(), opts)
}
func (d *Database) createAggregateLoad(target interface{}, composed *query.Query) (cursor *mongo.Cursor, err error) {
collection := d.GetCollectionOf(target)
opts := options.Aggregate()
pipeline := primitive.A{}
if !composed.Empty() {
pipeline = append(pipeline, primitive.M{"$match": primitive.M{"$expr": composed.M()}})
}
if composed.Sorter() != nil {
pipeline = append(pipeline, primitive.M{"$sort": composed.Sorter()})
}
if composed.Skipper() != nil {
pipeline = append(pipeline, primitive.M{"$skip": *composed.Skipper()})
}
if composed.Limiter() != nil {
pipeline = append(pipeline, primitive.M{"$limit": *composed.Limiter()})
}
el := reflect.ValueOf(target).Elem()
elType := el.Type()
numField := elType.NumField()
_, preloads := composed.Preloader()
for i := 0; i < numField; i++ {
field := elType.Field(i)
tag := field.Tag
preloadTag, ok := tag.Lookup("preload")
if !ok {
continue
}
jsonTag, ok := tag.Lookup("json")
if jsonTag == "-" {
return nil, fmt.Errorf("preload private field is impossible")
}
jsonData := strings.SplitN(jsonTag, ",", 2)
jsonName := field.Name
if len(jsonData) > 0 {
jsonName = strings.TrimSpace(jsonData[0])
}
preloadData := strings.Split(preloadTag, ",")
if len(preloadData) == 0 {
continue
}
if len(preloadData) == 1 {
panic("there is no foreign field")
}
localField := strings.TrimSpace(preloadData[0])
if len(localField) == 0 {
localField = "_id"
}
foreignField := strings.TrimSpace(preloadData[1])
if len(foreignField) == 0 {
panic("there is no foreign field")
}
preloadLimiter := 100
preloadReversed := false
if len(preloadData) > 2 {
stringLimit := strings.TrimSpace(preloadData[2])
intLimit := preloadLimiter
preloadReversed = strings.HasPrefix(stringLimit, "-")
if preloadReversed {
stringLimit = stringLimit[1:]
}
intLimit, err = strconv.Atoi(stringLimit)
if err == nil {
preloadLimiter = intLimit
}
}
for _, preload := range preloads {
if preload != jsonName {
continue
}
isSlice := el.Field(i).Kind() == reflect.Slice
typ := el.Field(i).Type()
if typ.Kind() == reflect.Slice {
typ = typ.Elem()
}
if typ.Kind() != reflect.Ptr {
panic("preload field should have ptr type")
}
lookupCollection := d.GetCollectionOf(reflect.Zero(typ).Interface())
lookupVars := primitive.M{"selector": "$" + localField}
lookupPipeline := primitive.A{
primitive.M{"$match": primitive.M{"$expr": primitive.M{"$eq": primitive.A{"$" + foreignField, "$$selector"}}}},
}
if preloadReversed {
lookupPipeline = append(lookupPipeline, primitive.M{"$sort": primitive.M{"_id": -1}})
}
if isSlice && preloadLimiter > 0 {
lookupPipeline = append(lookupPipeline, primitive.M{"$limit": preloadLimiter})
} else if !isSlice {
lookupPipeline = append(lookupPipeline, primitive.M{"$limit": 1})
}
pipeline = append(pipeline, primitive.M{
"$lookup": primitive.M{
"from": lookupCollection.Name(),
"let": lookupVars,
"pipeline": lookupPipeline,
"as": jsonName,
},
})
if isSlice {
continue
}
pipeline = append(pipeline, primitive.M{
"$unwind": primitive.M{
"preserveNullAndEmptyArrays": true,
"path": "$" + jsonName,
},
})
}
}
return collection.Aggregate(d.Context(), pipeline, opts)
}

@ -1,4 +1,4 @@
package common package database
import ( import (
"fmt" "fmt"
@ -7,12 +7,11 @@ import (
"go.mongodb.org/mongo-driver/bson/primitive" "go.mongodb.org/mongo-driver/bson/primitive"
"go.mongodb.org/mongo-driver/mongo/options" "go.mongodb.org/mongo-driver/mongo/options"
"github.com/mainnika/mongox-go-driver/v2/mongox"
"github.com/mainnika/mongox-go-driver/v2/mongox/base" "github.com/mainnika/mongox-go-driver/v2/mongox/base"
) )
// DeleteArray removes documents list from a database by their ids // DeleteArray removes documents list from a database by their ids
func DeleteArray(db mongox.Database, target interface{}) error { func (d *Database) DeleteArray(target interface{}) error {
targetV := reflect.ValueOf(target) targetV := reflect.ValueOf(target)
targetT := targetV.Type() targetT := targetV.Type()
@ -35,7 +34,7 @@ func DeleteArray(db mongox.Database, target interface{}) error {
zeroElem := reflect.Zero(targetSliceElemT) zeroElem := reflect.Zero(targetSliceElemT)
targetLen := targetSliceV.Len() targetLen := targetSliceV.Len()
collection := db.GetCollectionOf(zeroElem.Interface()) collection := d.GetCollectionOf(zeroElem.Interface())
opts := options.Delete() opts := options.Delete()
ids := primitive.A{} ids := primitive.A{}
@ -48,7 +47,7 @@ func DeleteArray(db mongox.Database, target interface{}) error {
return fmt.Errorf("can't delete zero elements") return fmt.Errorf("can't delete zero elements")
} }
result, err := collection.DeleteMany(db.Context(), primitive.M{"_id": primitive.M{"$in": ids}}, opts) result, err := collection.DeleteMany(d.Context(), primitive.M{"_id": primitive.M{"$in": ids}}, opts)
if err != nil { if err != nil {
return fmt.Errorf("can't create find and delete result: %w", err) return fmt.Errorf("can't create find and delete result: %w", err)
} }

@ -1,4 +1,4 @@
package common package database
import ( import (
"fmt" "fmt"
@ -8,15 +8,14 @@ import (
"go.mongodb.org/mongo-driver/mongo" "go.mongodb.org/mongo-driver/mongo"
"go.mongodb.org/mongo-driver/mongo/options" "go.mongodb.org/mongo-driver/mongo/options"
"github.com/mainnika/mongox-go-driver/v2/mongox"
"github.com/mainnika/mongox-go-driver/v2/mongox/base" "github.com/mainnika/mongox-go-driver/v2/mongox/base"
"github.com/mainnika/mongox-go-driver/v2/mongox/query" "github.com/mainnika/mongox-go-driver/v2/mongox/query"
) )
// DeleteOne removes a document from a database and then returns it into target // DeleteOne removes a document from a database and then returns it into target
func DeleteOne(db mongox.Database, target interface{}, filters ...interface{}) error { func (d *Database) DeleteOne(target interface{}, filters ...interface{}) error {
collection := db.GetCollectionOf(target) collection := d.GetCollectionOf(target)
opts := &options.FindOneAndDeleteOptions{} opts := &options.FindOneAndDeleteOptions{}
composed := query.Compose(filters...) composed := query.Compose(filters...)
protected := base.GetProtection(target) protected := base.GetProtection(target)
@ -33,7 +32,7 @@ func DeleteOne(db mongox.Database, target interface{}, filters ...interface{}) e
protected.V = time.Now().Unix() protected.V = time.Now().Unix()
} }
result := collection.FindOneAndDelete(db.Context(), composed.M(), opts) result := collection.FindOneAndDelete(d.Context(), composed.M(), opts)
if result.Err() != nil { if result.Err() != nil {
return fmt.Errorf("can't create find one and delete result: %w", result.Err()) return fmt.Errorf("can't create find one and delete result: %w", result.Err())
} }

@ -1,4 +1,4 @@
package common package database
import ( import (
"fmt" "fmt"
@ -6,13 +6,12 @@ import (
"go.mongodb.org/mongo-driver/mongo" "go.mongodb.org/mongo-driver/mongo"
"github.com/mainnika/mongox-go-driver/v2/mongox"
"github.com/mainnika/mongox-go-driver/v2/mongox/base" "github.com/mainnika/mongox-go-driver/v2/mongox/base"
"github.com/mainnika/mongox-go-driver/v2/mongox/query" "github.com/mainnika/mongox-go-driver/v2/mongox/query"
) )
// LoadArray loads an array of documents from the database by query // LoadArray loads an array of documents from the database by query
func LoadArray(db mongox.Database, target interface{}, filters ...interface{}) error { func (d *Database) LoadArray(target interface{}, filters ...interface{}) error {
targetV := reflect.ValueOf(target) targetV := reflect.ValueOf(target)
targetT := targetV.Type() targetT := targetV.Type()
@ -41,18 +40,18 @@ func LoadArray(db mongox.Database, target interface{}, filters ...interface{}) e
var err error var err error
if hasPreloader { if hasPreloader {
result, err = createAggregateLoad(db, zeroElem.Interface(), composed) result, err = d.createAggregateLoad(zeroElem.Interface(), composed)
} else { } else {
result, err = createSimpleLoad(db, zeroElem.Interface(), composed) result, err = d.createSimpleLoad(zeroElem.Interface(), composed)
} }
if err != nil { if err != nil {
return fmt.Errorf("can't create find result: %w", err) return fmt.Errorf("can't create find result: %w", err)
} }
defer result.Close(db.Context()) defer result.Close(d.Context())
var i int var i int
for i = 0; result.Next(db.Context()); { for i = 0; result.Next(d.Context()); {
if targetSliceV.Len() == i { if targetSliceV.Len() == i {
elem := reflect.New(targetSliceElemT.Elem()) elem := reflect.New(targetSliceElemT.Elem())
if err = result.Decode(elem.Interface()); err == nil { if err = result.Decode(elem.Interface()); err == nil {

@ -1,17 +1,16 @@
package common package database
import ( import (
"fmt" "fmt"
"go.mongodb.org/mongo-driver/mongo" "go.mongodb.org/mongo-driver/mongo"
"github.com/mainnika/mongox-go-driver/v2/mongox"
"github.com/mainnika/mongox-go-driver/v2/mongox/base" "github.com/mainnika/mongox-go-driver/v2/mongox/base"
"github.com/mainnika/mongox-go-driver/v2/mongox/query" "github.com/mainnika/mongox-go-driver/v2/mongox/query"
) )
// LoadOne function loads a first single target document by a query // LoadOne function loads a first single target document by a query
func LoadOne(db mongox.Database, target interface{}, filters ...interface{}) error { func (d *Database) LoadOne(target interface{}, filters ...interface{}) error {
composed := query.Compose(append(filters, query.Limit(1))...) composed := query.Compose(append(filters, query.Limit(1))...)
hasPreloader, _ := composed.Preloader() hasPreloader, _ := composed.Preloader()
@ -20,15 +19,15 @@ func LoadOne(db mongox.Database, target interface{}, filters ...interface{}) err
var err error var err error
if hasPreloader { if hasPreloader {
result, err = createAggregateLoad(db, target, composed) result, err = d.createAggregateLoad(target, composed)
} else { } else {
result, err = createSimpleLoad(db, target, composed) result, err = d.createSimpleLoad(target, composed)
} }
if err != nil { if err != nil {
return fmt.Errorf("can't create find result: %w", err) return fmt.Errorf("can't create find result: %w", err)
} }
hasNext := result.Next(db.Context()) hasNext := result.Next(d.Context())
if result.Err() != nil { if result.Err() != nil {
return err return err
} }

@ -0,0 +1,33 @@
package database
import (
"fmt"
"go.mongodb.org/mongo-driver/mongo"
"github.com/mainnika/mongox-go-driver/v2/mongox"
"github.com/mainnika/mongox-go-driver/v2/mongox/query"
)
// LoadStream function loads documents one by one into a target channel
func (d *Database) LoadStream(target interface{}, filters ...interface{}) (mongox.StreamLoader, error) {
var cursor *mongo.Cursor
var err error
composed := query.Compose(filters...)
hasPreloader, _ := composed.Preloader()
if hasPreloader {
cursor, err = d.createAggregateLoad(target, composed)
} else {
cursor, err = d.createSimpleLoad(target, composed)
}
if err != nil {
return nil, fmt.Errorf("can't create find result: %w", err)
}
l := &StreamLoader{Cursor: cursor, ctx: d.Context(), target: target}
return l, nil
}

@ -1,4 +1,4 @@
package common package database
import ( import (
"time" "time"
@ -7,15 +7,14 @@ import (
"go.mongodb.org/mongo-driver/bson/primitive" "go.mongodb.org/mongo-driver/bson/primitive"
"go.mongodb.org/mongo-driver/mongo/options" "go.mongodb.org/mongo-driver/mongo/options"
"github.com/mainnika/mongox-go-driver/v2/mongox"
"github.com/mainnika/mongox-go-driver/v2/mongox/base" "github.com/mainnika/mongox-go-driver/v2/mongox/base"
"github.com/mainnika/mongox-go-driver/v2/mongox/query" "github.com/mainnika/mongox-go-driver/v2/mongox/query"
) )
// SaveOne saves a single source document to the database // SaveOne saves a single source document to the database
func SaveOne(db mongox.Database, source interface{}) error { func (d *Database) SaveOne(source interface{}) error {
collection := db.GetCollectionOf(source) collection := d.GetCollectionOf(source)
opts := options.FindOneAndReplace() opts := options.FindOneAndReplace()
id := base.GetID(source) id := base.GetID(source)
protected := base.GetProtection(source) protected := base.GetProtection(source)
@ -30,7 +29,7 @@ func SaveOne(db mongox.Database, source interface{}) error {
protected.V = time.Now().Unix() protected.V = time.Now().Unix()
} }
result := collection.FindOneAndReplace(db.Context(), composed.M(), source, opts) result := collection.FindOneAndReplace(d.Context(), composed.M(), source, opts)
if result.Err() != nil { if result.Err() != nil {
return result.Err() return result.Err()
} }

@ -1,4 +1,4 @@
package common package database
import ( import (
"context" "context"
@ -6,9 +6,7 @@ import (
"go.mongodb.org/mongo-driver/mongo" "go.mongodb.org/mongo-driver/mongo"
"github.com/mainnika/mongox-go-driver/v2/mongox"
"github.com/mainnika/mongox-go-driver/v2/mongox/base" "github.com/mainnika/mongox-go-driver/v2/mongox/base"
"github.com/mainnika/mongox-go-driver/v2/mongox/query"
) )
// StreamLoader is a controller for a database cursor // StreamLoader is a controller for a database cursor
@ -32,7 +30,19 @@ func (l *StreamLoader) DecodeNext() error {
base.Reset(l.target) base.Reset(l.target)
err := l.Decode(l.target) err := l.Cursor.Decode(l.target)
if err != nil {
return fmt.Errorf("can't decode desult: %w", err)
}
return nil
}
func (l *StreamLoader) Decode() error {
base.Reset(l.target)
err := l.Cursor.Decode(l.target)
if err != nil { if err != nil {
return fmt.Errorf("can't decode desult: %w", err) return fmt.Errorf("can't decode desult: %w", err)
} }
@ -60,26 +70,3 @@ func (l *StreamLoader) Close() error {
return l.Cursor.Close(l.ctx) return l.Cursor.Close(l.ctx)
} }
// LoadStream function loads documents one by one into a target channel
func LoadStream(db mongox.Database, target interface{}, filters ...interface{}) (*StreamLoader, error) {
var cursor *mongo.Cursor
var err error
composed := query.Compose(filters...)
hasPreloader, _ := composed.Preloader()
if hasPreloader {
cursor, err = createAggregateLoad(db, target, composed)
} else {
cursor, err = createSimpleLoad(db, target, composed)
}
if err != nil {
return nil, fmt.Errorf("can't create find result: %w", err)
}
l := &StreamLoader{Cursor: cursor, ctx: db.Context(), target: target}
return l, nil
}

@ -16,6 +16,22 @@ type Database interface {
Name() string Name() string
New(ctx context.Context) Database New(ctx context.Context) Database
GetCollectionOf(document interface{}) MongoCollection GetCollectionOf(document interface{}) MongoCollection
Count(target interface{}, filters ...interface{}) (int64, error)
DeleteArray(target interface{}) error
DeleteOne(target interface{}, filters ...interface{}) error
LoadArray(target interface{}, filters ...interface{}) error
LoadOne(target interface{}, filters ...interface{}) error
LoadStream(target interface{}, filters ...interface{}) (StreamLoader, error)
SaveOne(source interface{}) error
}
// StreamLoader is a interface to control database cursor
type StreamLoader interface {
DecodeNext() error
Decode() error
Next() error
Close() error
Err() error
} }
// MongoClient is the mongo client interface // MongoClient is the mongo client interface

Loading…
Cancel
Save