13 Commits

Author SHA1 Message Date
Nikita Tokarchuk 08c3c5b377 Add callback mechanism and implement on-decode callback 2020-07-13 16:32:48 +02:00
Nikita Tokarchuk 09fa64ab0e Return err if cannot decode array element 2020-07-13 04:09:04 +02:00
Nikita Tokarchuk c019a0ea4b Ignore unused arguments 2020-07-13 04:09:04 +02:00
Nikita Tokarchuk ee1b0e17d5 Improve panic messages 2020-07-13 04:09:04 +02:00
Nikita Tokarchuk 22a1d7033f Remove custom errors 2020-07-13 04:09:04 +02:00
Nikita Tokarchuk 1d3e29fe10 Use err type for panics 2020-07-13 04:09:04 +02:00
Nikita Tokarchuk 72e74a65b6 Use named returns for the code style consistency 2020-07-13 04:09:04 +02:00
Nikita Tokarchuk 9cf3551c20 MIT licensed 2020-06-04 21:30:19 +02:00
Nikita Tokarchuk 3035d8d571 Fix aggregation pipeline match step 2020-06-04 18:15:59 +02:00
Nikita Tokarchuk eac50d1770 Do not use unnecessary reflect 2020-06-04 18:15:35 +02:00
Nikita Tokarchuk 05ebb25e70 Use unsafe pointer in the interface struct header is more correct way 2020-06-04 18:15:08 +02:00
Nikita Tokarchuk fd53c66690 Use ordered document for index model 2020-03-25 17:40:26 +01:00
Nikita Tokarchuk 6111341a3c Check for nil interface correctly 2020-03-24 21:31:29 +01:00
29 changed files with 320 additions and 190 deletions
+21
View File
@@ -0,0 +1,21 @@
MIT License
Copyright © 2020 Nikita Tokarchuk
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
+1 -1
View File
@@ -30,6 +30,6 @@ func NewEphemeral(URI string) (db *EphemeralDatabase, err error) {
} }
// Close the connection and drop database // Close the connection and drop database
func (e *EphemeralDatabase) Close() error { func (e *EphemeralDatabase) Close() (err error) {
return e.Client().Database(e.Name()).Drop(e.Context()) return e.Client().Database(e.Name()).Drop(e.Context())
} }
+2 -1
View File
@@ -6,6 +6,7 @@ import (
"go.mongodb.org/mongo-driver/bson/primitive" "go.mongodb.org/mongo-driver/bson/primitive"
"github.com/mainnika/mongox-go-driver/v2/mongox" "github.com/mainnika/mongox-go-driver/v2/mongox"
"github.com/mainnika/mongox-go-driver/v2/mongox/utils"
) )
// GetID returns source document id // GetID returns source document id
@@ -62,7 +63,7 @@ func getObjectOrPanic(source mongox.JSONBased) (id primitive.D) {
func getInterfaceOrPanic(source mongox.InterfaceBased) (id interface{}) { func getInterfaceOrPanic(source mongox.InterfaceBased) (id interface{}) {
id = source.GetID() id = source.GetID()
if id != nil { if !utils.IsNil(id) {
return id return id
} }
+7 -5
View File
@@ -7,11 +7,11 @@ import (
) )
// GetProtection function finds protection field in the source document otherwise returns nil // GetProtection function finds protection field in the source document otherwise returns nil
func GetProtection(source interface{}) *protection.Key { func GetProtection(source interface{}) (key *protection.Key) {
v := reflect.ValueOf(source) v := reflect.ValueOf(source)
if v.Kind() != reflect.Ptr || v.IsNil() { if v.Kind() != reflect.Ptr || v.IsNil() {
return nil return
} }
el := v.Elem() el := v.Elem()
@@ -25,14 +25,16 @@ func GetProtection(source interface{}) *protection.Key {
switch field.Interface().(type) { switch field.Interface().(type) {
case *protection.Key: case *protection.Key:
return field.Interface().(*protection.Key) key = field.Interface().(*protection.Key)
case protection.Key: case protection.Key:
ptr := field.Addr() ptr := field.Addr()
return ptr.Interface().(*protection.Key) key = ptr.Interface().(*protection.Key)
default: default:
continue continue
} }
return
} }
return nil return
} }
+1 -1
View File
@@ -14,7 +14,7 @@ type Primary struct {
} }
// GetID returns an _id // GetID returns an _id
func (p *Primary) GetID() primitive.D { func (p *Primary) GetID() (id primitive.D) {
return p.ID return p.ID
} }
+1 -1
View File
@@ -14,7 +14,7 @@ type Primary struct {
} }
// GetID returns an _id // GetID returns an _id
func (p *Primary) GetID() primitive.ObjectID { func (p *Primary) GetID() (id primitive.ObjectID) {
return p.ID return p.ID
} }
+2 -1
View File
@@ -1,6 +1,7 @@
package base package base
import ( import (
"fmt"
"reflect" "reflect"
) )
@@ -19,7 +20,7 @@ func Reset(target interface{}) {
v := reflect.ValueOf(target) v := reflect.ValueOf(target)
if v.Kind() != reflect.Ptr { if v.Kind() != reflect.Ptr {
panic("reset target should be a pointer") panic(fmt.Errorf("reset target should be a pointer"))
} }
t := v.Elem().Type() t := v.Elem().Type()
+1 -1
View File
@@ -12,7 +12,7 @@ type Primary struct {
} }
// GetID returns an _id // GetID returns an _id
func (p *Primary) GetID() string { func (p *Primary) GetID() (id string) {
return p.ID return p.ID
} }
+19
View File
@@ -0,0 +1,19 @@
package database
import (
"context"
"github.com/mainnika/mongox-go-driver/v2/mongox/query"
)
func onDecode(ctx context.Context, iter interface{}, callbacks ...query.OnDecode) (err error) {
for _, cb := range callbacks {
err = cb(ctx, iter)
if err != nil {
return
}
}
return
}
+3 -12
View File
@@ -1,17 +1,14 @@
package database package database
import ( import (
"fmt"
"go.mongodb.org/mongo-driver/mongo/options" "go.mongodb.org/mongo-driver/mongo/options"
"github.com/mainnika/mongox-go-driver/v2/mongox"
"github.com/mainnika/mongox-go-driver/v2/mongox/query" "github.com/mainnika/mongox-go-driver/v2/mongox/query"
) )
// Count function counts documents in the database by query // Count function counts documents in the database by query
// target is used only to get collection by tag so it'd be better to use nil ptr here // target is used only to get collection by tag so it'd be better to use nil ptr here
func (d *Database) Count(target interface{}, filters ...interface{}) (int64, error) { func (d *Database) Count(target interface{}, filters ...interface{}) (result int64, err error) {
collection := d.GetCollectionOf(target) collection := d.GetCollectionOf(target)
opts := options.Count() opts := options.Count()
@@ -20,13 +17,7 @@ func (d *Database) Count(target interface{}, filters ...interface{}) (int64, err
opts.Limit = composed.Limiter() opts.Limit = composed.Limiter()
opts.Skip = composed.Skipper() opts.Skip = composed.Skipper()
result, err := collection.CountDocuments(d.Context(), composed.M(), opts) result, err = collection.CountDocuments(d.Context(), composed.M(), opts)
if err == mongox.ErrNoDocuments {
return 0, err
}
if err != nil {
return 0, fmt.Errorf("can't decode desult: %w", err)
}
return result, nil return
} }
+19 -16
View File
@@ -22,17 +22,18 @@ type Database struct {
} }
// NewDatabase function creates new database instance with mongo client and empty context // NewDatabase function creates new database instance with mongo client and empty context
func NewDatabase(client *mongox.Client, dbname string) mongox.Database { func NewDatabase(client *mongox.Client, dbname string) (db mongox.Database) {
db := &Database{} db = &Database{
db.client = client client: client,
db.dbname = dbname dbname: dbname,
}
return db return
} }
// Client function returns a mongo client // Client function returns a mongo client
func (d *Database) Client() *mongox.Client { func (d *Database) Client() (client *mongox.Client) {
return d.client return d.client
} }
@@ -48,22 +49,24 @@ func (d *Database) Context() (ctx context.Context) {
} }
// Name function returns a database name // Name function returns a database name
func (d *Database) Name() string { func (d *Database) Name() (name string) {
return d.dbname return d.dbname
} }
// New function creates new database context with same client // New function creates new database context with same client
func (d *Database) New(ctx context.Context) mongox.Database { func (d *Database) New(ctx context.Context) (db mongox.Database) {
if ctx == nil { if ctx == nil {
ctx = context.Background() ctx = context.Background()
} }
return &Database{ db = &Database{
client: d.client, client: d.client,
dbname: d.dbname, dbname: d.dbname,
ctx: ctx, ctx: ctx,
} }
return
} }
// GetCollectionOf returns the collection object by the «collection» tag of the given document; // GetCollectionOf returns the collection object by the «collection» tag of the given document;
@@ -72,7 +75,7 @@ func (d *Database) New(ctx context.Context) mongox.Database {
// base.ObjectID `bson:",inline" json:",inline" collection:"foobars"` // base.ObjectID `bson:",inline" json:",inline" collection:"foobars"`
// ... // ...
// Will panic if there is no «collection» tag // Will panic if there is no «collection» tag
func (d *Database) GetCollectionOf(document interface{}) *mongox.Collection { func (d *Database) GetCollectionOf(document interface{}) (collection *mongox.Collection) {
el := reflect.TypeOf(document).Elem() el := reflect.TypeOf(document).Elem()
numField := el.NumField() numField := el.NumField()
@@ -111,7 +114,7 @@ func (d *Database) createAggregateLoad(target interface{}, composed *query.Query
pipeline := primitive.A{} pipeline := primitive.A{}
if !composed.Empty() { if !composed.Empty() {
pipeline = append(pipeline, primitive.M{"$match": primitive.M{"$expr": composed.M()}}) pipeline = append(pipeline, primitive.M{"$match": composed.M()})
} }
if composed.Sorter() != nil { if composed.Sorter() != nil {
pipeline = append(pipeline, primitive.M{"$sort": composed.Sorter()}) pipeline = append(pipeline, primitive.M{"$sort": composed.Sorter()})
@@ -137,9 +140,9 @@ func (d *Database) createAggregateLoad(target interface{}, composed *query.Query
if !ok { if !ok {
continue continue
} }
jsonTag, ok := tag.Lookup("json") jsonTag, _ := tag.Lookup("json")
if jsonTag == "-" { if jsonTag == "-" {
return nil, fmt.Errorf("preload private field is impossible") panic(fmt.Errorf("preload private field is impossible"))
} }
jsonData := strings.SplitN(jsonTag, ",", 2) jsonData := strings.SplitN(jsonTag, ",", 2)
@@ -153,7 +156,7 @@ func (d *Database) createAggregateLoad(target interface{}, composed *query.Query
continue continue
} }
if len(preloadData) == 1 { if len(preloadData) == 1 {
panic("there is no foreign field") panic(fmt.Errorf("there is no foreign field"))
} }
localField := strings.TrimSpace(preloadData[0]) localField := strings.TrimSpace(preloadData[0])
@@ -163,7 +166,7 @@ func (d *Database) createAggregateLoad(target interface{}, composed *query.Query
foreignField := strings.TrimSpace(preloadData[1]) foreignField := strings.TrimSpace(preloadData[1])
if len(foreignField) == 0 { if len(foreignField) == 0 {
panic("there is no foreign field") panic(fmt.Errorf("there is no foreign field"))
} }
preloadLimiter := 100 preloadLimiter := 100
@@ -195,7 +198,7 @@ func (d *Database) createAggregateLoad(target interface{}, composed *query.Query
typ = typ.Elem() typ = typ.Elem()
} }
if typ.Kind() != reflect.Ptr { if typ.Kind() != reflect.Ptr {
panic("preload field should have ptr type") panic(fmt.Errorf("preload field should have ptr type"))
} }
lookupCollection := d.GetCollectionOf(reflect.Zero(typ).Interface()) lookupCollection := d.GetCollectionOf(reflect.Zero(typ).Interface())
+4 -4
View File
@@ -11,7 +11,7 @@ import (
) )
// DeleteArray removes documents list from a database by their ids // DeleteArray removes documents list from a database by their ids
func (d *Database) DeleteArray(target interface{}) error { func (d *Database) DeleteArray(target interface{}) (err error) {
targetV := reflect.ValueOf(target) targetV := reflect.ValueOf(target)
targetT := targetV.Type() targetT := targetV.Type()
@@ -49,11 +49,11 @@ func (d *Database) DeleteArray(target interface{}) error {
result, err := collection.DeleteMany(d.Context(), primitive.M{"_id": primitive.M{"$in": ids}}, opts) result, err := collection.DeleteMany(d.Context(), primitive.M{"_id": primitive.M{"$in": ids}}, opts)
if err != nil { if err != nil {
return fmt.Errorf("can't create find and delete result: %w", err) return
} }
if result.DeletedCount != int64(targetLen) { if result.DeletedCount != int64(targetLen) {
return fmt.Errorf("can't verify delete result: removed count mismatch %d != %d", result.DeletedCount, targetLen) err = fmt.Errorf("can't verify delete result: removed count mismatch %d != %d", result.DeletedCount, targetLen)
} }
return nil return
} }
+5 -11
View File
@@ -7,13 +7,13 @@ import (
"go.mongodb.org/mongo-driver/bson/primitive" "go.mongodb.org/mongo-driver/bson/primitive"
"go.mongodb.org/mongo-driver/mongo/options" "go.mongodb.org/mongo-driver/mongo/options"
"github.com/mainnika/mongox-go-driver/v2/mongox"
"github.com/mainnika/mongox-go-driver/v2/mongox/base" "github.com/mainnika/mongox-go-driver/v2/mongox/base"
"github.com/mainnika/mongox-go-driver/v2/mongox/query" "github.com/mainnika/mongox-go-driver/v2/mongox/query"
"github.com/mainnika/mongox-go-driver/v2/mongox/utils"
) )
// DeleteOne removes a document from a database and then returns it into target // DeleteOne removes a document from a database and then returns it into target
func (d *Database) DeleteOne(target interface{}, filters ...interface{}) error { func (d *Database) DeleteOne(target interface{}, filters ...interface{}) (err error) {
collection := d.GetCollectionOf(target) collection := d.GetCollectionOf(target)
opts := &options.FindOneAndDeleteOptions{} opts := &options.FindOneAndDeleteOptions{}
@@ -22,7 +22,7 @@ func (d *Database) DeleteOne(target interface{}, filters ...interface{}) error {
opts.Sort = composed.Sorter() opts.Sort = composed.Sorter()
if target != nil { if !utils.IsNil(target) {
composed.And(primitive.M{"_id": base.GetID(target)}) composed.And(primitive.M{"_id": base.GetID(target)})
} }
@@ -37,13 +37,7 @@ func (d *Database) DeleteOne(target interface{}, filters ...interface{}) error {
return fmt.Errorf("can't create find one and delete result: %w", result.Err()) return fmt.Errorf("can't create find one and delete result: %w", result.Err())
} }
err := result.Decode(target) err = result.Decode(target)
if err == mongox.ErrNoDocuments {
return err
}
if err != nil {
return fmt.Errorf("can't decode result: %w", err)
}
return nil return
} }
+13 -11
View File
@@ -21,7 +21,7 @@ import (
// `index:"-,+foo,+-bar,unique,allowNull"` -- https://docs.mongodb.com/manual/core/index-partial // `index:"-,+foo,+-bar,unique,allowNull"` -- https://docs.mongodb.com/manual/core/index-partial
// `index:"-,unique,allowNull,expireAfter=86400"` -- https://docs.mongodb.com/manual/core/index-ttl // `index:"-,unique,allowNull,expireAfter=86400"` -- https://docs.mongodb.com/manual/core/index-ttl
// `index:"-,unique,allowNull,expireAfter={{.Expire}}"` -- evaluate index as a golang template with `cfg` arguments // `index:"-,unique,allowNull,expireAfter={{.Expire}}"` -- evaluate index as a golang template with `cfg` arguments
func (d *Database) IndexEnsure(cfg interface{}, document interface{}) error { func (d *Database) IndexEnsure(cfg interface{}, document interface{}) (err error) {
el := reflect.ValueOf(document).Elem().Type() el := reflect.ValueOf(document).Elem().Type()
numField := el.NumField() numField := el.NumField()
@@ -41,14 +41,16 @@ func (d *Database) IndexEnsure(cfg interface{}, document interface{}) error {
return fmt.Errorf("bson tag is not defined for field:%v document:%v", field, document) return fmt.Errorf("bson tag is not defined for field:%v document:%v", field, document)
} }
tmpBuffer := &bytes.Buffer{} var tmpBuffer = &bytes.Buffer{}
tpl, err := template.New("").Parse(indexTag) var tpl *template.Template
tpl, err = template.New("").Parse(indexTag)
if err != nil { if err != nil {
panic(fmt.Errorf("invalid prop template, %v", indexTag)) panic(fmt.Errorf("invalid prop template %v, err:%w", indexTag, err))
} }
err = tpl.Execute(tmpBuffer, cfg) err = tpl.Execute(tmpBuffer, cfg)
if err != nil { if err != nil {
panic(fmt.Errorf("failed to evaluate prop template, %v", indexTag)) panic(fmt.Errorf("failed to evaluate prop template %v, err:%w", indexTag, err))
} }
indexString := tmpBuffer.String() indexString := tmpBuffer.String()
@@ -64,15 +66,15 @@ func (d *Database) IndexEnsure(cfg interface{}, document interface{}) error {
panic(fmt.Errorf("cannot evaluate index key")) panic(fmt.Errorf("cannot evaluate index key"))
} }
index := primitive.M{key: 1}
opts := &options.IndexOptions{ opts := &options.IndexOptions{
Background: &f, Background: &f,
Unique: &f, Unique: &f,
Name: &name, Name: &name,
} }
index := primitive.D{{Key: key, Value: 1}}
if indexValues[0] == "-" { if indexValues[0] == "-" {
index[key] = -1 index = primitive.D{{Key: key, Value: -1}}
} }
for _, prop := range indexValues[1:] { for _, prop := range indexValues[1:] {
@@ -114,9 +116,9 @@ func (d *Database) IndexEnsure(cfg interface{}, document interface{}) error {
} }
if compoundValue[0] == '-' { if compoundValue[0] == '-' {
index[compoundValue[1:]] = -1 index = append(index, primitive.E{compoundValue[1:], -1})
} else { } else {
index[compoundValue] = 1 index = append(index, primitive.E{compoundValue, 1})
} }
default: default:
@@ -126,9 +128,9 @@ func (d *Database) IndexEnsure(cfg interface{}, document interface{}) error {
_, err = documents.Indexes().CreateOne(d.Context(), mongo.IndexModel{Keys: index, Options: opts}) _, err = documents.Indexes().CreateOne(d.Context(), mongo.IndexModel{Keys: index, Options: opts})
if err != nil { if err != nil {
return err return
} }
} }
return nil return
} }
+23 -13
View File
@@ -10,7 +10,7 @@ import (
) )
// LoadArray loads an array of documents from the database by query // LoadArray loads an array of documents from the database by query
func (d *Database) LoadArray(target interface{}, filters ...interface{}) error { func (d *Database) LoadArray(target interface{}, filters ...interface{}) (err error) {
targetV := reflect.ValueOf(target) targetV := reflect.ValueOf(target)
targetT := targetV.Type() targetT := targetV.Type()
@@ -36,7 +36,7 @@ func (d *Database) LoadArray(target interface{}, filters ...interface{}) error {
hasPreloader, _ := composed.Preloader() hasPreloader, _ := composed.Preloader()
var result *mongox.Cursor var result *mongox.Cursor
var err error var i int
if hasPreloader { if hasPreloader {
result, err = d.createAggregateLoad(zeroElem.Interface(), composed) result, err = d.createAggregateLoad(zeroElem.Interface(), composed)
@@ -44,25 +44,35 @@ func (d *Database) LoadArray(target interface{}, filters ...interface{}) error {
result, err = d.createSimpleLoad(zeroElem.Interface(), composed) result, err = d.createSimpleLoad(zeroElem.Interface(), composed)
} }
if err != nil { if err != nil {
return fmt.Errorf("can't create find result: %w", err) err = fmt.Errorf("can't create find result: %w", err)
return
} }
var i int
for i = 0; result.Next(d.Context()); { for i = 0; result.Next(d.Context()); {
var elem interface{}
if targetSliceV.Len() == i { if targetSliceV.Len() == i {
elem := reflect.New(targetSliceElemT.Elem()) value := reflect.New(targetSliceElemT.Elem())
if err = result.Decode(elem.Interface()); err == nil { err = result.Decode(value.Interface())
targetSliceV = reflect.Append(targetSliceV, elem) elem = value.Interface()
} else { if err == nil {
continue targetSliceV = reflect.Append(targetSliceV, value)
} }
} else { } else {
elem := targetSliceV.Index(i).Interface() elem = targetSliceV.Index(i).Interface()
base.Reset(elem) base.Reset(elem)
if err = result.Decode(elem); err != nil { err = result.Decode(elem)
continue
} }
if err != nil {
_ = result.Close(d.Context())
return
}
err = onDecode(d.ctx, elem, composed.OnDecode()...)
if err != nil {
_ = result.Close(d.Context())
return
} }
i++ i++
+12 -3
View File
@@ -9,13 +9,12 @@ import (
) )
// LoadOne function loads a first single target document by a query // LoadOne function loads a first single target document by a query
func (d *Database) LoadOne(target interface{}, filters ...interface{}) error { func (d *Database) LoadOne(target interface{}, filters ...interface{}) (err error) {
composed := query.Compose(append(filters, query.Limit(1))...) composed := query.Compose(append(filters, query.Limit(1))...)
hasPreloader, _ := composed.Preloader() hasPreloader, _ := composed.Preloader()
var result *mongox.Cursor var result *mongox.Cursor
var err error
if hasPreloader { if hasPreloader {
result, err = d.createAggregateLoad(target, composed) result, err = d.createAggregateLoad(target, composed)
@@ -36,5 +35,15 @@ func (d *Database) LoadOne(target interface{}, filters ...interface{}) error {
base.Reset(target) base.Reset(target)
return result.Decode(target) err = result.Decode(target)
if err != nil {
return
}
err = onDecode(d.ctx, target, composed.OnDecode()...)
if err != nil {
return
}
return
} }
+5 -5
View File
@@ -8,10 +8,9 @@ import (
) )
// LoadStream function loads documents one by one into a target channel // LoadStream function loads documents one by one into a target channel
func (d *Database) LoadStream(target interface{}, filters ...interface{}) (mongox.StreamLoader, error) { func (d *Database) LoadStream(target interface{}, filters ...interface{}) (loader mongox.StreamLoader, err error) {
var cursor *mongox.Cursor var cursor *mongox.Cursor
var err error
composed := query.Compose(filters...) composed := query.Compose(filters...)
hasPreloader, _ := composed.Preloader() hasPreloader, _ := composed.Preloader()
@@ -22,10 +21,11 @@ func (d *Database) LoadStream(target interface{}, filters ...interface{}) (mongo
cursor, err = d.createSimpleLoad(target, composed) cursor, err = d.createSimpleLoad(target, composed)
} }
if err != nil { if err != nil {
return nil, fmt.Errorf("can't create find result: %w", err) err = fmt.Errorf("can't create find result: %w", err)
return
} }
l := &StreamLoader{cur: cursor, ctx: d.Context(), target: target} loader = &StreamLoader{cur: cursor, ctx: d.Context(), target: target, query: composed}
return l, nil return
} }
+1 -1
View File
@@ -12,7 +12,7 @@ import (
) )
// SaveOne saves a single source document to the database // SaveOne saves a single source document to the database
func (d *Database) SaveOne(source interface{}) error { func (d *Database) SaveOne(source interface{}) (err error) {
collection := d.GetCollectionOf(source) collection := d.GetCollectionOf(source)
opts := options.FindOneAndReplace() opts := options.FindOneAndReplace()
+29 -28
View File
@@ -2,78 +2,79 @@ package database
import ( import (
"context" "context"
"fmt"
"github.com/mainnika/mongox-go-driver/v2/mongox" "github.com/mainnika/mongox-go-driver/v2/mongox"
"github.com/mainnika/mongox-go-driver/v2/mongox/base" "github.com/mainnika/mongox-go-driver/v2/mongox/base"
"github.com/mainnika/mongox-go-driver/v2/mongox/query"
) )
// StreamLoader is a controller for a database cursor // StreamLoader is a controller for a database cursor
type StreamLoader struct { type StreamLoader struct {
cur *mongox.Cursor cur *mongox.Cursor
query *query.Query
ctx context.Context ctx context.Context
target interface{} target interface{}
} }
// DecodeNext loads next documents to a target or returns an error // DecodeNext loads next documents to a target or returns an error
func (l *StreamLoader) DecodeNext() error { func (l *StreamLoader) DecodeNext() (err error) {
hasNext := l.cur.Next(l.ctx) err = l.Next()
if l.cur.Err() != nil {
return l.cur.Err()
}
if !hasNext {
return mongox.ErrNoDocuments
}
base.Reset(l.target)
err := l.cur.Decode(l.target)
if err != nil { if err != nil {
return fmt.Errorf("can't decode desult: %w", err) return
} }
return nil err = l.Decode()
if err != nil {
return
}
return
} }
// Decode function decodes the current cursor document into the target // Decode function decodes the current cursor document into the target
func (l *StreamLoader) Decode() error { func (l *StreamLoader) Decode() (err error) {
base.Reset(l.target) base.Reset(l.target)
err := l.cur.Decode(l.target) err = l.cur.Decode(l.target)
if err != nil { if err != nil {
return fmt.Errorf("can't decode desult: %w", err) return
} }
return nil err = onDecode(l.ctx, l.target, l.query.OnDecode()...)
if err != nil {
return
}
return
} }
// Next loads next documents but doesn't perform decoding // Next loads next documents but doesn't perform decoding
func (l *StreamLoader) Next() error { func (l *StreamLoader) Next() (err error) {
hasNext := l.cur.Next(l.ctx) hasNext := l.cur.Next(l.ctx)
err = l.cur.Err()
if l.cur.Err() != nil { if err != nil {
return l.cur.Err() return
} }
if !hasNext { if !hasNext {
return mongox.ErrNoDocuments err = mongox.ErrNoDocuments
} }
return nil return
} }
func (l *StreamLoader) Cursor() *mongox.Cursor { func (l *StreamLoader) Cursor() (cursor *mongox.Cursor) {
return l.cur return l.cur
} }
// Close cursor // Close cursor
func (l *StreamLoader) Close() error { func (l *StreamLoader) Close() (err error) {
return l.cur.Close(l.ctx) return l.cur.Close(l.ctx)
} }
func (l *StreamLoader) Err() error { func (l *StreamLoader) Err() (err error) {
return l.cur.Err() return l.cur.Err()
} }
+23 -23
View File
@@ -16,51 +16,51 @@ type (
// Database is the mongox database interface // Database is the mongox database interface
type Database interface { type Database interface {
Client() *Client Client() (client *Client)
Context() context.Context Context() (context context.Context)
Name() string Name() (name string)
New(ctx context.Context) Database New(ctx context.Context) (db Database)
GetCollectionOf(document interface{}) *Collection GetCollectionOf(document interface{}) (collection *Collection)
Count(target interface{}, filters ...interface{}) (int64, error) Count(target interface{}, filters ...interface{}) (count int64, err error)
DeleteArray(target interface{}) error DeleteArray(target interface{}) (err error)
DeleteOne(target interface{}, filters ...interface{}) error DeleteOne(target interface{}, filters ...interface{}) (err error)
LoadArray(target interface{}, filters ...interface{}) error LoadArray(target interface{}, filters ...interface{}) (err error)
LoadOne(target interface{}, filters ...interface{}) error LoadOne(target interface{}, filters ...interface{}) (err error)
LoadStream(target interface{}, filters ...interface{}) (StreamLoader, error) LoadStream(target interface{}, filters ...interface{}) (loader StreamLoader, err error)
SaveOne(source interface{}) error SaveOne(source interface{}) (err error)
IndexEnsure(cfg interface{}, document interface{}) error IndexEnsure(cfg interface{}, document interface{}) (err error)
} }
// StreamLoader is a interface to control database cursor // StreamLoader is a interface to control database cursor
type StreamLoader interface { type StreamLoader interface {
Cursor() *Cursor Cursor() (cursor *Cursor)
DecodeNext() error DecodeNext() (err error)
Decode() error Decode() (err error)
Next() error Next() (err error)
Close() error Close() (err error)
Err() error Err() (err error)
} }
// OIDBased is an interface for documents that have objectId type for the _id field // OIDBased is an interface for documents that have objectId type for the _id field
type OIDBased interface { type OIDBased interface {
GetID() primitive.ObjectID GetID() (id primitive.ObjectID)
SetID(id primitive.ObjectID) SetID(id primitive.ObjectID)
} }
// StringBased is an interface for documents that have string type for the _id field // StringBased is an interface for documents that have string type for the _id field
type StringBased interface { type StringBased interface {
GetID() string GetID() (id string)
SetID(id string) SetID(id string)
} }
// JSONBased is an interface for documents that have object type for the _id field // JSONBased is an interface for documents that have object type for the _id field
type JSONBased interface { type JSONBased interface {
GetID() primitive.D GetID() (id primitive.D)
SetID(id primitive.D) SetID(id primitive.D)
} }
// InterfaceBased is an interface for documents that have custom declated type for the _id field // InterfaceBased is an interface for documents that have custom declated type for the _id field
type InterfaceBased interface { type InterfaceBased interface {
GetID() interface{} GetID() (id interface{})
SetID(id interface{}) SetID(id interface{})
} }
+7
View File
@@ -0,0 +1,7 @@
package query
import (
"context"
)
type OnDecode func(ctx context.Context, iter interface{}) (err error)
+26 -13
View File
@@ -7,42 +7,44 @@ import (
"go.mongodb.org/mongo-driver/bson/primitive" "go.mongodb.org/mongo-driver/bson/primitive"
"github.com/mainnika/mongox-go-driver/v2/mongox/base/protection" "github.com/mainnika/mongox-go-driver/v2/mongox/base/protection"
"github.com/mainnika/mongox-go-driver/v2/mongox/utils"
) )
// Compose is a function to compose filters into a single query // Compose is a function to compose filters into a single query
func Compose(filters ...interface{}) *Query { func Compose(filters ...interface{}) (query *Query) {
q := &Query{} query = &Query{}
for _, f := range filters { for _, f := range filters {
if !Push(q, f) { if !Push(query, f) {
panic(fmt.Errorf("unknown filter %v", f)) panic(fmt.Errorf("unknown filter %v", f))
} }
} }
return q return
} }
// Push applies single filter to a query // Push applies single filter to a query
func Push(q *Query, f interface{}) bool { func Push(q *Query, f interface{}) (ok bool) {
if f == nil { if utils.IsNil(f) {
return true return true
} }
ok := false ok = false
ok = ok || applyBson(q, f) ok = ok || applyBson(q, f)
ok = ok || applyLimit(q, f) ok = ok || applyLimit(q, f)
ok = ok || applySort(q, f) ok = ok || applySort(q, f)
ok = ok || applySkip(q, f) ok = ok || applySkip(q, f)
ok = ok || applyProtection(q, f) ok = ok || applyProtection(q, f)
ok = ok || applyPreloader(q, f) ok = ok || applyPreloader(q, f)
ok = ok || applyCallbacks(q, f)
return ok return ok
} }
// applyBson is a fallback for a custom bson.M // applyBson is a fallback for a custom bson.M
func applyBson(q *Query, f interface{}) bool { func applyBson(q *Query, f interface{}) (ok bool) {
if f, ok := f.(bson.M); ok { if f, ok := f.(bson.M); ok {
q.And(f) q.And(f)
@@ -53,7 +55,7 @@ func applyBson(q *Query, f interface{}) bool {
} }
// applyLimits extends query with a limiter // applyLimits extends query with a limiter
func applyLimit(q *Query, f interface{}) bool { func applyLimit(q *Query, f interface{}) (ok bool) {
if f, ok := f.(Limiter); ok { if f, ok := f.(Limiter); ok {
q.limiter = f q.limiter = f
@@ -64,7 +66,7 @@ func applyLimit(q *Query, f interface{}) bool {
} }
// applySort extends query with a sort rule // applySort extends query with a sort rule
func applySort(q *Query, f interface{}) bool { func applySort(q *Query, f interface{}) (ok bool) {
if f, ok := f.(Sorter); ok { if f, ok := f.(Sorter); ok {
q.sorter = f q.sorter = f
@@ -75,7 +77,7 @@ func applySort(q *Query, f interface{}) bool {
} }
// applySkip extends query with a skip number // applySkip extends query with a skip number
func applySkip(q *Query, f interface{}) bool { func applySkip(q *Query, f interface{}) (ok bool) {
if f, ok := f.(Skipper); ok { if f, ok := f.(Skipper); ok {
q.skipper = f q.skipper = f
@@ -85,7 +87,7 @@ func applySkip(q *Query, f interface{}) bool {
return false return false
} }
func applyProtection(q *Query, f interface{}) bool { func applyProtection(q *Query, f interface{}) (ok bool) {
var x *primitive.ObjectID var x *primitive.ObjectID
var v *int64 var v *int64
@@ -113,7 +115,7 @@ func applyProtection(q *Query, f interface{}) bool {
return true return true
} }
func applyPreloader(q *Query, f interface{}) bool { func applyPreloader(q *Query, f interface{}) (ok bool) {
if f, ok := f.(Preloader); ok { if f, ok := f.(Preloader); ok {
q.preloader = f q.preloader = f
@@ -122,3 +124,14 @@ func applyPreloader(q *Query, f interface{}) bool {
return false return false
} }
func applyCallbacks(q *Query, f interface{}) (ok bool) {
switch cb := f.(type) {
case OnDecode:
q.ondecode = append(q.ondecode, cb)
ok = true
}
return
}
+8 -6
View File
@@ -2,7 +2,7 @@ package query
// Limiter is a filter to limit the result // Limiter is a filter to limit the result
type Limiter interface { type Limiter interface {
Limit() *int64 Limit() (limit *int64)
} }
// Limit is a simple implementation of the Limiter filter // Limit is a simple implementation of the Limiter filter
@@ -11,12 +11,14 @@ type Limit int64
var _ Limiter = Limit(0) var _ Limiter = Limit(0)
// Limit returns a limit // Limit returns a limit
func (l Limit) Limit() *int64 { func (l Limit) Limit() (limit *int64) {
lim := int64(l) if l <= 0 {
if lim <= 0 { return
return nil
} }
return &lim limit = new(int64)
*limit = int64(l)
return
} }
+2 -2
View File
@@ -2,7 +2,7 @@ package query
// Preloader is a filter to skip the result // Preloader is a filter to skip the result
type Preloader interface { type Preloader interface {
Preload() []string Preload() (preloads []string)
} }
// Preload is a simple implementation of the Skipper filter // Preload is a simple implementation of the Skipper filter
@@ -11,6 +11,6 @@ type Preload []string
var _ Preloader = Preload{} var _ Preloader = Preload{}
// Preload returns a preload list // Preload returns a preload list
func (l Preload) Preload() []string { func (l Preload) Preload() (preloads []string) {
return l return l
} }
+19 -22
View File
@@ -2,8 +2,6 @@ package query
import ( import (
"go.mongodb.org/mongo-driver/bson" "go.mongodb.org/mongo-driver/bson"
"reflect"
) )
// Query is an enchanched bson.M map // Query is an enchanched bson.M map
@@ -13,10 +11,11 @@ type Query struct {
sorter Sorter sorter Sorter
skipper Skipper skipper Skipper
preloader Preloader preloader Preloader
ondecode []OnDecode
} }
// And function pushes the elem query to the $and array of the query // And function pushes the elem query to the $and array of the query
func (q *Query) And(elem bson.M) *Query { func (q *Query) And(elem bson.M) (query *Query) {
if q.m == nil { if q.m == nil {
q.m = bson.M{} q.m = bson.M{}
@@ -35,61 +34,59 @@ func (q *Query) And(elem bson.M) *Query {
} }
// Limiter returns limiter value or nil // Limiter returns limiter value or nil
func (q *Query) Limiter() *int64 { func (q *Query) Limiter() (limit *int64) {
if q.limiter == nil { if q.limiter == nil {
return nil return
} }
return q.limiter.Limit() return q.limiter.Limit()
} }
// Sorter is a sort rule for a query // Sorter is a sort rule for a query
func (q *Query) Sorter() interface{} { func (q *Query) Sorter() (sort interface{}) {
if q.sorter == nil { if q.sorter == nil {
return nil return
} }
return q.sorter.Sort() return q.sorter.Sort()
} }
// Skipper is a skipper for a query // Skipper is a skipper for a query
func (q *Query) Skipper() *int64 { func (q *Query) Skipper() (skip *int64) {
if q.skipper == nil { if q.skipper == nil {
return nil return
} }
return q.skipper.Skip() return q.skipper.Skip()
} }
// Preloader is a preloader list for a query // Preloader is a preloader list for a query
func (q *Query) Preloader() (empty bool, preloader []string) { func (q *Query) Preloader() (ok bool, preloads []string) {
if q.preloader == nil { if q.preloader == nil {
return false, nil return false, nil
} }
preloader = q.preloader.Preload() preloads = q.preloader.Preload()
ok = len(preloads) > 0
if len(preloader) == 0 { return
return false, nil }
}
return true, preloader // OnDecode callback is called after the mongo decode function
func (q *Query) OnDecode() (callbacks []OnDecode) {
return q.ondecode
} }
// Empty checks the query for any content // Empty checks the query for any content
func (q *Query) Empty() bool { func (q *Query) Empty() (isEmpty bool) {
return len(q.m) == 0
qv := reflect.ValueOf(q.m)
keys := qv.MapKeys()
return len(keys) == 0
} }
// M returns underlying query map // M returns underlying query map
func (q *Query) M() bson.M { func (q *Query) M() (m bson.M) {
return q.m return q.m
} }
+8 -6
View File
@@ -2,7 +2,7 @@ package query
// Skipper is a filter to skip the result // Skipper is a filter to skip the result
type Skipper interface { type Skipper interface {
Skip() *int64 Skip() (skip *int64)
} }
// Skip is a simple implementation of the Skipper filter // Skip is a simple implementation of the Skipper filter
@@ -11,12 +11,14 @@ type Skip int64
var _ Skipper = Skip(0) var _ Skipper = Skip(0)
// Skip returns a skip number // Skip returns a skip number
func (l Skip) Skip() *int64 { func (l Skip) Skip() (skip *int64) {
lim := int64(l) if l <= 0 {
if lim <= 0 { return
return nil
} }
return &lim skip = new(int64)
*skip = int64(l)
return
} }
+2 -2
View File
@@ -6,7 +6,7 @@ import (
// Sorter is a filter to sort the data before query // Sorter is a filter to sort the data before query
type Sorter interface { type Sorter interface {
Sort() bson.M Sort() (sort bson.M)
} }
// Sort is a simple implementations of the Sorter filter // Sort is a simple implementations of the Sorter filter
@@ -15,6 +15,6 @@ type Sort bson.M
var _ Sorter = &Sort{} var _ Sorter = &Sort{}
// Sort returns a slice of fields which have to be sorted // Sort returns a slice of fields which have to be sorted
func (f Sort) Sort() bson.M { func (f Sort) Sort() (sort bson.M) {
return bson.M(f) return bson.M(f)
} }
+23
View File
@@ -0,0 +1,23 @@
package utils
import (
"unsafe"
)
// IsNil function evaluates the interface value to nil
func IsNil(i interface{}) (isNil bool) {
type iface struct {
_ unsafe.Pointer
ptr unsafe.Pointer
}
unpacked := (*iface)(unsafe.Pointer(&i))
if unpacked.ptr == nil {
isNil = true
return
}
isNil = *(*unsafe.Pointer)(unpacked.ptr) == nil
return
}
+32
View File
@@ -0,0 +1,32 @@
package utils
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestIsNil(t *testing.T) {
testvalues := []struct {
i interface{}
isnil bool
}{
{nil, true},
{(*string)(nil), true},
{([]string)(nil), true},
{(map[string]string)(nil), true},
{(func() bool)(nil), true},
{(chan func() bool)(nil), true},
{"", true},
{0, true},
{append(([]string)(nil), ""), false},
{[]string{}, false},
{1, false},
{"1", false},
}
for _, tt := range testvalues {
assert.Equal(t, tt.isnil, IsNil(tt.i))
}
}