mirror of
https://github.com/mainnika/mongox-go-driver.git
synced 2026-05-22 15:53:36 +00:00
Prepare for v3
This commit is contained in:
@@ -0,0 +1,44 @@
|
||||
package docbased
|
||||
|
||||
import (
|
||||
"github.com/modern-go/reflect2"
|
||||
"go.mongodb.org/mongo-driver/bson/primitive"
|
||||
|
||||
"github.com/mainnika/mongox-go-driver/v2/mongox"
|
||||
)
|
||||
|
||||
var _ mongox.DocBased = (*Primary)(nil)
|
||||
|
||||
// Primary is a structure with object as an _id field
|
||||
type Primary struct {
|
||||
ID primitive.D `bson:"_id" json:"_id"`
|
||||
}
|
||||
|
||||
// GetID returns an _id
|
||||
func (p *Primary) GetID() (id primitive.D) {
|
||||
return p.ID
|
||||
}
|
||||
|
||||
// SetID sets an _id
|
||||
func (p *Primary) SetID(id primitive.D) {
|
||||
p.ID = id
|
||||
}
|
||||
|
||||
// New creates a new Primary structure with a defined _id
|
||||
func New(e primitive.E, ee ...primitive.E) Primary {
|
||||
id := primitive.D{e}
|
||||
if len(ee) > 0 {
|
||||
id = append(id, ee...)
|
||||
}
|
||||
|
||||
return Primary{ID: id}
|
||||
}
|
||||
|
||||
func GetID(source mongox.DocBased) (id primitive.D, err error) {
|
||||
id = source.GetID()
|
||||
if !reflect2.IsNil(id) {
|
||||
return id, nil
|
||||
}
|
||||
|
||||
return nil, mongox.ErrUninitializedBase
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
package jsonbased_test
|
||||
package docbased_test
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
@@ -8,27 +8,25 @@ import (
|
||||
"go.mongodb.org/mongo-driver/bson/primitive"
|
||||
|
||||
"github.com/mainnika/mongox-go-driver/v2/mongox-testing/database"
|
||||
"github.com/mainnika/mongox-go-driver/v2/mongox/base/jsonbased"
|
||||
"github.com/mainnika/mongox-go-driver/v2/mongox/base/docbased"
|
||||
)
|
||||
|
||||
func Test_GetID(t *testing.T) {
|
||||
|
||||
type DocWithObject struct {
|
||||
jsonbased.Primary `bson:",inline" json:",inline" collection:"1"`
|
||||
docbased.Primary `bson:",inline" json:",inline" collection:"1"`
|
||||
}
|
||||
|
||||
doc := &DocWithObject{Primary: jsonbased.Primary{ID: primitive.D{{"1", "one"}, {"2", "two"}}}}
|
||||
doc := &DocWithObject{Primary: docbased.New(primitive.E{"1", "one"}, primitive.E{"2", "two"})}
|
||||
|
||||
assert.Equal(t, primitive.D{{"1", "one"}, {"2", "two"}}, doc.GetID())
|
||||
}
|
||||
|
||||
func Test_SetID(t *testing.T) {
|
||||
|
||||
type DocWithObject struct {
|
||||
jsonbased.Primary `bson:",inline" json:",inline" collection:"1"`
|
||||
docbased.Primary `bson:",inline" json:",inline" collection:"1"`
|
||||
}
|
||||
|
||||
doc := &DocWithObject{Primary: jsonbased.Primary{ID: primitive.D{{"1", "one"}, {"2", "two"}}}}
|
||||
doc := &DocWithObject{Primary: docbased.New(primitive.E{"1", "one"}, primitive.E{"2", "two"})}
|
||||
|
||||
doc.SetID(primitive.D{{"3", "three"}, {"4", "you"}})
|
||||
|
||||
@@ -37,9 +35,8 @@ func Test_SetID(t *testing.T) {
|
||||
}
|
||||
|
||||
func Test_SaveLoad(t *testing.T) {
|
||||
|
||||
type DocWithObjectID struct {
|
||||
jsonbased.Primary `bson:",inline" json:",inline" collection:"1"`
|
||||
docbased.Primary `bson:",inline" json:",inline" collection:"1"`
|
||||
}
|
||||
|
||||
db, err := database.NewEphemeral("")
|
||||
@@ -47,9 +44,9 @@ func Test_SaveLoad(t *testing.T) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
defer db.Close()
|
||||
defer func() { _ = db.Close() }()
|
||||
|
||||
doc1 := &DocWithObjectID{Primary: jsonbased.Primary{ID: primitive.D{{"1", "one"}, {"2", "two"}}}}
|
||||
doc1 := &DocWithObjectID{Primary: docbased.New(primitive.E{"1", "one"}, primitive.E{"2", "two"})}
|
||||
doc2 := &DocWithObjectID{}
|
||||
|
||||
err = db.SaveOne(doc1)
|
||||
@@ -67,12 +64,11 @@ func Test_SaveLoad(t *testing.T) {
|
||||
}
|
||||
|
||||
func Test_Marshal(t *testing.T) {
|
||||
|
||||
type DocWithObjectID struct {
|
||||
jsonbased.Primary `bson:",inline" json:",inline" collection:"1"`
|
||||
docbased.Primary `bson:",inline" json:",inline" collection:"1"`
|
||||
}
|
||||
|
||||
doc := &DocWithObjectID{Primary: jsonbased.Primary{ID: primitive.D{{"1", "one"}, {"2", "two"}}}}
|
||||
doc := &DocWithObjectID{Primary: docbased.New(primitive.E{"1", "one"}, primitive.E{"2", "two"})}
|
||||
|
||||
bytes, err := json.Marshal(doc)
|
||||
assert.NoError(t, err)
|
||||
+11
-55
@@ -2,70 +2,26 @@ package base
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/modern-go/reflect2"
|
||||
"go.mongodb.org/mongo-driver/bson/primitive"
|
||||
"github.com/mainnika/mongox-go-driver/v2/mongox/base/docbased"
|
||||
"github.com/mainnika/mongox-go-driver/v2/mongox/base/ifacebased"
|
||||
"github.com/mainnika/mongox-go-driver/v2/mongox/base/oidbased"
|
||||
"github.com/mainnika/mongox-go-driver/v2/mongox/base/stringbased"
|
||||
|
||||
"github.com/mainnika/mongox-go-driver/v2/mongox"
|
||||
)
|
||||
|
||||
// GetID returns source document id
|
||||
func GetID(source interface{}) (id interface{}) {
|
||||
|
||||
func GetID(source interface{}) (id interface{}, err error) {
|
||||
switch doc := source.(type) {
|
||||
case mongox.OIDBased:
|
||||
return getObjectIDOrGenerate(doc)
|
||||
return oidbased.GetID(doc)
|
||||
case mongox.StringBased:
|
||||
return getStringIDOrPanic(doc)
|
||||
case mongox.JSONBased:
|
||||
return getObjectOrPanic(doc)
|
||||
return stringbased.GetID(doc)
|
||||
case mongox.DocBased:
|
||||
return docbased.GetID(doc)
|
||||
case mongox.InterfaceBased:
|
||||
return getInterfaceOrPanic(doc)
|
||||
|
||||
return ifacebased.GetID(doc)
|
||||
default:
|
||||
panic(fmt.Errorf("source contains malformed document, %v", source))
|
||||
return nil, fmt.Errorf("%w: unknown base type", mongox.ErrMalformedBase)
|
||||
}
|
||||
}
|
||||
|
||||
func getObjectIDOrGenerate(source mongox.OIDBased) (id primitive.ObjectID) {
|
||||
|
||||
id = source.GetID()
|
||||
if id != primitive.NilObjectID {
|
||||
return id
|
||||
}
|
||||
|
||||
id = primitive.NewObjectID()
|
||||
source.SetID(id)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func getStringIDOrPanic(source mongox.StringBased) (id string) {
|
||||
|
||||
id = source.GetID()
|
||||
if id != "" {
|
||||
return id
|
||||
}
|
||||
|
||||
panic(fmt.Errorf("source contains malformed document, %v", source))
|
||||
}
|
||||
|
||||
func getObjectOrPanic(source mongox.JSONBased) (id primitive.D) {
|
||||
|
||||
id = source.GetID()
|
||||
if id != nil {
|
||||
return id
|
||||
}
|
||||
|
||||
panic(fmt.Errorf("source contains malformed document, %v", source))
|
||||
}
|
||||
|
||||
func getInterfaceOrPanic(source mongox.InterfaceBased) (id interface{}) {
|
||||
|
||||
id = source.GetID()
|
||||
if !reflect2.IsNil(id) {
|
||||
return id
|
||||
}
|
||||
|
||||
panic(fmt.Errorf("source contains malformed document, %v", source))
|
||||
}
|
||||
|
||||
@@ -1,13 +1,14 @@
|
||||
package base_test
|
||||
|
||||
import (
|
||||
"github.com/mainnika/mongox-go-driver/v2/mongox/base/docbased"
|
||||
"github.com/stretchr/testify/require"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"go.mongodb.org/mongo-driver/bson/primitive"
|
||||
|
||||
"github.com/mainnika/mongox-go-driver/v2/mongox/base"
|
||||
"github.com/mainnika/mongox-go-driver/v2/mongox/base/jsonbased"
|
||||
"github.com/mainnika/mongox-go-driver/v2/mongox/base/oidbased"
|
||||
"github.com/mainnika/mongox-go-driver/v2/mongox/base/stringbased"
|
||||
)
|
||||
@@ -29,15 +30,24 @@ func TestGetID(t *testing.T) {
|
||||
type DocWithObjectID struct {
|
||||
oidbased.Primary `bson:",inline" json:",inline" collection:"1"`
|
||||
}
|
||||
id, err := base.GetID(&DocWithObjectID{Primary: oidbased.Primary{ID: [12]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2}}})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, primitive.ObjectID([12]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2}), id)
|
||||
|
||||
type DocWithObject struct {
|
||||
jsonbased.Primary `bson:",inline" json:",inline" collection:"2"`
|
||||
docbased.Primary `bson:",inline" json:",inline" collection:"2"`
|
||||
}
|
||||
id, err = base.GetID(&DocWithObject{Primary: docbased.Primary{ID: primitive.D{{"1", "2"}}}})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, primitive.D{{"1", "2"}}, id)
|
||||
|
||||
type DocWithString struct {
|
||||
stringbased.Primary `bson:",inline" json:",inline" collection:"3"`
|
||||
}
|
||||
id, err = base.GetID(&DocWithString{Primary: stringbased.Primary{ID: "foobar"}})
|
||||
assert.Equal(t, "foobar", id)
|
||||
|
||||
assert.Equal(t, primitive.ObjectID([12]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2}), base.GetID(&DocWithObjectID{Primary: oidbased.Primary{ID: [12]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2}}}))
|
||||
assert.Equal(t, primitive.D{{"1", "2"}}, base.GetID(&DocWithObject{Primary: jsonbased.Primary{ID: primitive.D{{"1", "2"}}}}))
|
||||
assert.Equal(t, "foobar", base.GetID(&DocWithString{Primary: stringbased.Primary{ID: "foobar"}}))
|
||||
assert.Equal(t, 420, base.GetID(&DocWithCustomInterface{ID: 420}))
|
||||
id, err = base.GetID(&DocWithCustomInterface{ID: 420})
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 420, id)
|
||||
}
|
||||
|
||||
@@ -1,40 +0,0 @@
|
||||
package base
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
|
||||
"github.com/mainnika/mongox-go-driver/v2/mongox/base/protection"
|
||||
)
|
||||
|
||||
// GetProtection function finds protection field in the source document otherwise returns nil
|
||||
func GetProtection(source interface{}) (key *protection.Key) {
|
||||
|
||||
v := reflect.ValueOf(source)
|
||||
if v.Kind() != reflect.Ptr || v.IsNil() {
|
||||
return
|
||||
}
|
||||
|
||||
el := v.Elem()
|
||||
numField := el.NumField()
|
||||
|
||||
for i := 0; i < numField; i++ {
|
||||
field := el.Field(i)
|
||||
if !field.CanInterface() {
|
||||
continue
|
||||
}
|
||||
|
||||
switch field.Interface().(type) {
|
||||
case *protection.Key:
|
||||
key = field.Interface().(*protection.Key)
|
||||
case protection.Key:
|
||||
ptr := field.Addr()
|
||||
key = ptr.Interface().(*protection.Key)
|
||||
default:
|
||||
continue
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
@@ -0,0 +1,16 @@
|
||||
package ifacebased
|
||||
|
||||
import (
|
||||
"github.com/mainnika/mongox-go-driver/v2/mongox"
|
||||
"github.com/modern-go/reflect2"
|
||||
)
|
||||
|
||||
// GetID returns an _id from the source document
|
||||
func GetID(source mongox.InterfaceBased) (id interface{}, err error) {
|
||||
id = source.GetID()
|
||||
if !reflect2.IsNil(id) {
|
||||
return id, nil
|
||||
}
|
||||
|
||||
return nil, mongox.ErrUninitializedBase
|
||||
}
|
||||
@@ -1,24 +0,0 @@
|
||||
package jsonbased
|
||||
|
||||
import (
|
||||
"go.mongodb.org/mongo-driver/bson/primitive"
|
||||
|
||||
"github.com/mainnika/mongox-go-driver/v2/mongox"
|
||||
)
|
||||
|
||||
var _ mongox.JSONBased = (*Primary)(nil)
|
||||
|
||||
// Primary is a structure with object as an _id field
|
||||
type Primary struct {
|
||||
ID primitive.D `bson:"_id" json:"_id"`
|
||||
}
|
||||
|
||||
// GetID returns an _id
|
||||
func (p *Primary) GetID() (id primitive.D) {
|
||||
return p.ID
|
||||
}
|
||||
|
||||
// SetID sets an _id
|
||||
func (p *Primary) SetID(id primitive.D) {
|
||||
p.ID = id
|
||||
}
|
||||
@@ -22,3 +22,22 @@ func (p *Primary) GetID() (id primitive.ObjectID) {
|
||||
func (p *Primary) SetID(id primitive.ObjectID) {
|
||||
p.ID = id
|
||||
}
|
||||
|
||||
// Generate creates a new Primary structure with a new objectId
|
||||
func Generate() Primary {
|
||||
return Primary{ID: primitive.NewObjectID()}
|
||||
}
|
||||
|
||||
// New creates a new Primary structure with a defined objectId
|
||||
func New(id primitive.ObjectID) Primary {
|
||||
return Primary{ID: id}
|
||||
}
|
||||
|
||||
func GetID(source mongox.OIDBased) (id primitive.ObjectID, err error) {
|
||||
id = source.GetID()
|
||||
if id != primitive.NilObjectID {
|
||||
return id, nil
|
||||
}
|
||||
|
||||
return primitive.NilObjectID, mongox.ErrUninitializedBase
|
||||
}
|
||||
|
||||
@@ -12,24 +12,22 @@ import (
|
||||
)
|
||||
|
||||
func Test_GetID(t *testing.T) {
|
||||
|
||||
type DocWithObjectID struct {
|
||||
oidbased.Primary `bson:",inline" json:",inline" collection:"1"`
|
||||
}
|
||||
|
||||
doc := &DocWithObjectID{Primary: oidbased.Primary{ID: [12]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2}}}
|
||||
doc := &DocWithObjectID{Primary: oidbased.New([12]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2})}
|
||||
|
||||
assert.Equal(t, primitive.ObjectID([12]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2}), doc.Primary.ID)
|
||||
assert.Equal(t, primitive.ObjectID([12]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2}), doc.GetID())
|
||||
}
|
||||
|
||||
func Test_SetID(t *testing.T) {
|
||||
|
||||
type DocWithObjectID struct {
|
||||
oidbased.Primary `bson:",inline" json:",inline" collection:"1"`
|
||||
}
|
||||
|
||||
doc := &DocWithObjectID{}
|
||||
|
||||
doc := &DocWithObjectID{Primary: oidbased.Generate()}
|
||||
doc.SetID([12]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2})
|
||||
|
||||
assert.Equal(t, primitive.ObjectID([12]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2}), doc.Primary.ID)
|
||||
@@ -37,7 +35,6 @@ func Test_SetID(t *testing.T) {
|
||||
}
|
||||
|
||||
func Test_SaveLoad(t *testing.T) {
|
||||
|
||||
type DocWithObjectID struct {
|
||||
oidbased.Primary `bson:",inline" json:",inline" collection:"1"`
|
||||
}
|
||||
@@ -47,10 +44,10 @@ func Test_SaveLoad(t *testing.T) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
defer db.Close()
|
||||
defer func() { _ = db.Close() }()
|
||||
|
||||
doc1 := &DocWithObjectID{}
|
||||
doc2 := &DocWithObjectID{}
|
||||
doc1 := &DocWithObjectID{Primary: oidbased.Generate()}
|
||||
doc2 := &DocWithObjectID{Primary: oidbased.Generate()}
|
||||
|
||||
err = db.SaveOne(doc1)
|
||||
assert.NoError(t, err)
|
||||
@@ -67,13 +64,12 @@ func Test_SaveLoad(t *testing.T) {
|
||||
}
|
||||
|
||||
func Test_Marshal(t *testing.T) {
|
||||
|
||||
type DocWithObjectID struct {
|
||||
oidbased.Primary `bson:",inline" json:",inline" collection:"1"`
|
||||
}
|
||||
|
||||
id, _ := primitive.ObjectIDFromHex("feadbeeffeadbeeffeadbeef")
|
||||
doc := &DocWithObjectID{Primary: oidbased.Primary{ID: id}}
|
||||
doc := &DocWithObjectID{Primary: oidbased.New(id)}
|
||||
|
||||
bytes, err := json.Marshal(doc)
|
||||
assert.NoError(t, err)
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package protection
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"time"
|
||||
|
||||
"github.com/modern-go/reflect2"
|
||||
@@ -13,9 +14,8 @@ type Key struct {
|
||||
V int64 `bson:"_v" json:"_v"`
|
||||
}
|
||||
|
||||
// PutToDocument extends the doc with protection key values
|
||||
func (k *Key) PutToDocument(doc primitive.M) {
|
||||
|
||||
// Inject extends the doc with protection key values
|
||||
func (k *Key) Inject(doc primitive.M) {
|
||||
if reflect2.IsNil(doc) {
|
||||
return
|
||||
}
|
||||
@@ -34,3 +34,35 @@ func (k *Key) Restate() {
|
||||
k.X = primitive.NewObjectID()
|
||||
k.V = time.Now().Unix()
|
||||
}
|
||||
|
||||
// Get finds protection field in the source document otherwise returns nil
|
||||
func Get(source interface{}) (key *Key) {
|
||||
v := reflect.ValueOf(source)
|
||||
if v.Kind() != reflect.Ptr || v.IsNil() {
|
||||
return nil
|
||||
}
|
||||
|
||||
el := v.Elem()
|
||||
numField := el.NumField()
|
||||
|
||||
for i := 0; i < numField; i++ {
|
||||
field := el.Field(i)
|
||||
if !field.CanInterface() {
|
||||
continue
|
||||
}
|
||||
|
||||
switch field.Interface().(type) {
|
||||
case *Key:
|
||||
key = field.Interface().(*Key)
|
||||
case Key:
|
||||
ptr := field.Addr()
|
||||
key = ptr.Interface().(*Key)
|
||||
default:
|
||||
continue
|
||||
}
|
||||
|
||||
return key
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -0,0 +1,3 @@
|
||||
package protection_test
|
||||
|
||||
// TODO:
|
||||
@@ -7,7 +7,6 @@ import (
|
||||
|
||||
// Reset function creates new zero object for the target pointer
|
||||
func Reset(target interface{}) (created bool) {
|
||||
|
||||
type resetter interface {
|
||||
Reset()
|
||||
}
|
||||
|
||||
@@ -20,3 +20,17 @@ func (p *Primary) GetID() (id string) {
|
||||
func (p *Primary) SetID(id string) {
|
||||
p.ID = id
|
||||
}
|
||||
|
||||
// New creates a new Primary structure with a defined _id
|
||||
func New(id string) Primary {
|
||||
return Primary{ID: id}
|
||||
}
|
||||
|
||||
func GetID(source mongox.StringBased) (id string, err error) {
|
||||
id = source.GetID()
|
||||
if id != "" {
|
||||
return id, nil
|
||||
}
|
||||
|
||||
return "", mongox.ErrUninitializedBase
|
||||
}
|
||||
|
||||
@@ -11,23 +11,21 @@ import (
|
||||
)
|
||||
|
||||
func Test_GetID(t *testing.T) {
|
||||
|
||||
type DocWithString struct {
|
||||
stringbased.Primary `bson:",inline" json:",inline" collection:"1"`
|
||||
}
|
||||
|
||||
doc := &DocWithString{Primary: stringbased.Primary{ID: "foobar"}}
|
||||
doc := &DocWithString{Primary: stringbased.New("foobar")}
|
||||
|
||||
assert.Equal(t, "foobar", doc.GetID())
|
||||
}
|
||||
|
||||
func Test_SetID(t *testing.T) {
|
||||
|
||||
type DocWithString struct {
|
||||
stringbased.Primary `bson:",inline" json:",inline" collection:"1"`
|
||||
}
|
||||
|
||||
doc := &DocWithString{Primary: stringbased.Primary{ID: "foobar"}}
|
||||
doc := &DocWithString{Primary: stringbased.New("foobar")}
|
||||
|
||||
doc.SetID("rockrockrock")
|
||||
|
||||
@@ -36,7 +34,6 @@ func Test_SetID(t *testing.T) {
|
||||
}
|
||||
|
||||
func Test_SaveLoad(t *testing.T) {
|
||||
|
||||
type DocWithObjectID struct {
|
||||
stringbased.Primary `bson:",inline" json:",inline" collection:"1"`
|
||||
}
|
||||
@@ -46,9 +43,9 @@ func Test_SaveLoad(t *testing.T) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
defer db.Close()
|
||||
defer func() { _ = db.Close() }()
|
||||
|
||||
doc1 := &DocWithObjectID{Primary: stringbased.Primary{ID: "foobar"}}
|
||||
doc1 := &DocWithObjectID{Primary: stringbased.New("foobar")}
|
||||
doc2 := &DocWithObjectID{}
|
||||
|
||||
err = db.SaveOne(doc1)
|
||||
@@ -66,12 +63,11 @@ func Test_SaveLoad(t *testing.T) {
|
||||
}
|
||||
|
||||
func Test_Marshal(t *testing.T) {
|
||||
|
||||
type DocWithObjectID struct {
|
||||
stringbased.Primary `bson:",inline" json:",inline" collection:"1"`
|
||||
}
|
||||
|
||||
doc := &DocWithObjectID{Primary: stringbased.Primary{ID: "foobar"}}
|
||||
doc := &DocWithObjectID{Primary: stringbased.New("foobar")}
|
||||
|
||||
bytes, err := json.Marshal(doc)
|
||||
assert.NoError(t, err)
|
||||
|
||||
@@ -0,0 +1,23 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"context"
|
||||
)
|
||||
|
||||
type ctxDatabaseKey struct{}
|
||||
|
||||
// GetFromContext function extracts the request data from context
|
||||
func GetFromContext(ctx context.Context) (q *Database, ok bool) {
|
||||
q, ok = ctx.Value(ctxDatabaseKey{}).(*Database)
|
||||
if !ok {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
return q, true
|
||||
}
|
||||
|
||||
// WithContext creates the new context with a database attached
|
||||
func WithContext(ctx context.Context, q *Database) (withQuery context.Context) {
|
||||
db := NewDatabase(ctx, q.Client(), q.Name())
|
||||
return context.WithValue(ctx, ctxDatabaseKey{}, db)
|
||||
}
|
||||
@@ -9,22 +9,29 @@ import (
|
||||
// Count function counts documents in the database by query
|
||||
// target is used only to get collection by tag so it'd be better to use nil ptr here
|
||||
func (d *Database) Count(target interface{}, filters ...interface{}) (result int64, err error) {
|
||||
|
||||
composed, err := query.Compose(filters...)
|
||||
if err != nil {
|
||||
return
|
||||
return -1, err
|
||||
}
|
||||
|
||||
collection := d.GetCollectionOf(target)
|
||||
collection, err := d.GetCollectionOf(target)
|
||||
if err != nil {
|
||||
return -1, err
|
||||
}
|
||||
ctx := query.WithContext(d.Context(), composed)
|
||||
|
||||
m := composed.M()
|
||||
|
||||
opts := options.Count()
|
||||
opts.Limit = composed.Limiter()
|
||||
opts.Skip = composed.Skipper()
|
||||
|
||||
result, err = collection.CountDocuments(ctx, composed.M(), opts)
|
||||
defer func() { _ = composed.OnClose().Invoke(ctx, target) }()
|
||||
|
||||
_ = composed.OnClose().Invoke(ctx, target)
|
||||
result, err = collection.CountDocuments(ctx, m, opts)
|
||||
if err != nil {
|
||||
return -1, err
|
||||
}
|
||||
|
||||
return
|
||||
return result, nil
|
||||
}
|
||||
|
||||
@@ -0,0 +1,188 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"reflect"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"go.mongodb.org/mongo-driver/bson/primitive"
|
||||
"go.mongodb.org/mongo-driver/mongo/options"
|
||||
|
||||
"github.com/mainnika/mongox-go-driver/v2/mongox"
|
||||
"github.com/mainnika/mongox-go-driver/v2/mongox/query"
|
||||
)
|
||||
|
||||
func (d *Database) createCursor(target interface{}, composed *query.Query) (cursor *mongox.Cursor, err error) {
|
||||
_, hasPreloader := composed.Preloader()
|
||||
if hasPreloader {
|
||||
return d.createAggregateCursor(target, composed)
|
||||
}
|
||||
|
||||
return d.createSimpleCursor(target, composed)
|
||||
}
|
||||
|
||||
func (d *Database) createSimpleCursor(target interface{}, composed *query.Query) (cursor *mongox.Cursor, err error) {
|
||||
collection, err := d.GetCollectionOf(target)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
opts := options.Find()
|
||||
opts.Sort = composed.Sorter()
|
||||
opts.Limit = composed.Limiter()
|
||||
opts.Skip = composed.Skipper()
|
||||
|
||||
ctx := d.Context()
|
||||
m := composed.M()
|
||||
|
||||
return collection.Find(ctx, m, opts)
|
||||
}
|
||||
|
||||
func (d *Database) createAggregateCursor(target interface{}, composed *query.Query) (cursor *mongox.Cursor, err error) {
|
||||
collection, err := d.GetCollectionOf(target)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
pipeline := primitive.A{}
|
||||
if !composed.Empty() {
|
||||
pipeline = append(pipeline, primitive.M{"$match": composed.M()})
|
||||
}
|
||||
if composed.Sorter() != nil {
|
||||
pipeline = append(pipeline, primitive.M{"$sort": composed.Sorter()})
|
||||
}
|
||||
if composed.Skipper() != nil {
|
||||
pipeline = append(pipeline, primitive.M{"$skip": *composed.Skipper()})
|
||||
}
|
||||
if composed.Limiter() != nil {
|
||||
pipeline = append(pipeline, primitive.M{"$limit": *composed.Limiter()})
|
||||
}
|
||||
|
||||
el := reflect.ValueOf(target)
|
||||
elType := el.Type()
|
||||
if elType.Kind() == reflect.Ptr {
|
||||
elType = elType.Elem()
|
||||
}
|
||||
|
||||
numField := elType.NumField()
|
||||
preloads, _ := composed.Preloader()
|
||||
for i := 0; i < numField; i++ {
|
||||
field := elType.Field(i)
|
||||
tag := field.Tag
|
||||
|
||||
preloadTag, ok := tag.Lookup("preload")
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
jsonTag, _ := tag.Lookup("json")
|
||||
if jsonTag == "-" {
|
||||
return nil, fmt.Errorf("%w: private field is not preloadable", mongox.ErrMalformedBase)
|
||||
}
|
||||
|
||||
jsonData := strings.SplitN(jsonTag, ",", 2)
|
||||
jsonName := field.Name
|
||||
if len(jsonData) > 0 {
|
||||
jsonName = strings.TrimSpace(jsonData[0])
|
||||
}
|
||||
|
||||
preloadData := strings.Split(preloadTag, ",")
|
||||
if len(preloadData) == 0 {
|
||||
continue
|
||||
}
|
||||
if len(preloadData) == 1 {
|
||||
return nil, fmt.Errorf("%w: foreign field is not specified", mongox.ErrMalformedBase)
|
||||
}
|
||||
|
||||
foreignField := strings.TrimSpace(preloadData[1])
|
||||
if len(foreignField) == 0 {
|
||||
return nil, fmt.Errorf("%w: foreign field is empty", mongox.ErrMalformedBase)
|
||||
}
|
||||
localField := strings.TrimSpace(preloadData[0])
|
||||
if len(localField) == 0 {
|
||||
localField = "_id"
|
||||
}
|
||||
|
||||
preloadLimiter := 100
|
||||
preloadReversed := false
|
||||
if len(preloadData) > 2 {
|
||||
stringLimit := strings.TrimSpace(preloadData[2])
|
||||
intLimit := preloadLimiter
|
||||
|
||||
preloadReversed = strings.HasPrefix(stringLimit, "-")
|
||||
if preloadReversed {
|
||||
stringLimit = stringLimit[1:]
|
||||
}
|
||||
|
||||
intLimit, err = strconv.Atoi(stringLimit)
|
||||
if err == nil {
|
||||
preloadLimiter = intLimit
|
||||
} else {
|
||||
return nil, fmt.Errorf("%w: preload limit should be an integer", mongox.ErrMalformedBase)
|
||||
}
|
||||
}
|
||||
|
||||
for _, preload := range preloads {
|
||||
if preload != jsonName {
|
||||
continue
|
||||
}
|
||||
|
||||
field := elType.Field(i)
|
||||
fieldType := field.Type
|
||||
|
||||
isSlice := fieldType.Kind() == reflect.Slice
|
||||
if isSlice {
|
||||
fieldType = fieldType.Elem()
|
||||
}
|
||||
|
||||
isPtr := fieldType.Kind() != reflect.Ptr
|
||||
if isPtr {
|
||||
return nil, fmt.Errorf("%w: preload field should have ptr type", mongox.ErrMalformedBase)
|
||||
}
|
||||
|
||||
lookupCollection, err := d.GetCollectionOf(reflect.Zero(fieldType).Interface())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
lookupVars := primitive.M{"selector": "$" + localField}
|
||||
lookupPipeline := primitive.A{
|
||||
primitive.M{"$match": primitive.M{"$expr": primitive.M{"$eq": primitive.A{"$" + foreignField, "$$selector"}}}},
|
||||
}
|
||||
|
||||
if preloadReversed {
|
||||
lookupPipeline = append(lookupPipeline, primitive.M{"$sort": primitive.M{"_id": -1}})
|
||||
}
|
||||
if isSlice && preloadLimiter > 0 {
|
||||
lookupPipeline = append(lookupPipeline, primitive.M{"$limit": preloadLimiter})
|
||||
} else if !isSlice {
|
||||
lookupPipeline = append(lookupPipeline, primitive.M{"$limit": 1})
|
||||
}
|
||||
|
||||
pipeline = append(pipeline, primitive.M{
|
||||
"$lookup": primitive.M{
|
||||
"from": lookupCollection.Name(),
|
||||
"let": lookupVars,
|
||||
"pipeline": lookupPipeline,
|
||||
"as": jsonName,
|
||||
},
|
||||
})
|
||||
|
||||
if isSlice {
|
||||
continue
|
||||
}
|
||||
|
||||
pipeline = append(pipeline, primitive.M{
|
||||
"$unwind": primitive.M{
|
||||
"preserveNullAndEmptyArrays": true,
|
||||
"path": "$" + jsonName,
|
||||
},
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
ctx := d.Context()
|
||||
opts := options.Aggregate()
|
||||
|
||||
return collection.Aggregate(ctx, pipeline, opts)
|
||||
}
|
||||
+22
-200
@@ -2,34 +2,27 @@ package database
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"go.mongodb.org/mongo-driver/bson/primitive"
|
||||
"go.mongodb.org/mongo-driver/mongo/options"
|
||||
|
||||
"github.com/mainnika/mongox-go-driver/v2/mongox"
|
||||
"github.com/mainnika/mongox-go-driver/v2/mongox/query"
|
||||
)
|
||||
|
||||
// Database handler
|
||||
type Database struct {
|
||||
client *mongox.Client
|
||||
dbname string
|
||||
name string
|
||||
ctx context.Context
|
||||
}
|
||||
|
||||
// NewDatabase function creates new database instance with mongo client and empty context
|
||||
func NewDatabase(client *mongox.Client, dbname string) (db mongox.Database) {
|
||||
|
||||
func NewDatabase(ctx context.Context, client *mongox.Client, name string) (db mongox.Database) {
|
||||
db = &Database{
|
||||
client: client,
|
||||
dbname: dbname,
|
||||
name: name,
|
||||
ctx: ctx,
|
||||
}
|
||||
|
||||
return
|
||||
return db
|
||||
}
|
||||
|
||||
// Client function returns a mongo client
|
||||
@@ -37,213 +30,42 @@ func (d *Database) Client() (client *mongox.Client) {
|
||||
return d.client
|
||||
}
|
||||
|
||||
// Name function returns a database name
|
||||
func (d *Database) Name() (name string) {
|
||||
return d.name
|
||||
}
|
||||
|
||||
// Context function returns a context
|
||||
func (d *Database) Context() (ctx context.Context) {
|
||||
|
||||
ctx = d.ctx
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// Name function returns a database name
|
||||
func (d *Database) Name() (name string) {
|
||||
return d.dbname
|
||||
}
|
||||
|
||||
// New function creates new database context with same client
|
||||
func (d *Database) New(ctx context.Context) (db mongox.Database) {
|
||||
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
|
||||
db = &Database{
|
||||
client: d.client,
|
||||
dbname: d.dbname,
|
||||
ctx: ctx,
|
||||
}
|
||||
|
||||
return
|
||||
return ctx
|
||||
}
|
||||
|
||||
// GetCollectionOf returns the collection object by the «collection» tag of the given document;
|
||||
// the «collection» tag should exists, e.g.:
|
||||
// type Foobar struct {
|
||||
// base.ObjectID `bson:",inline" json:",inline" collection:"foobars"`
|
||||
// ...
|
||||
// Will panic if there is no «collection» tag
|
||||
func (d *Database) GetCollectionOf(document interface{}) (collection *mongox.Collection) {
|
||||
|
||||
//
|
||||
// example:
|
||||
// type Foobar struct {
|
||||
// base.ObjectID `bson:",inline" json:",inline" collection:"foobars"`
|
||||
// ...
|
||||
func (d *Database) GetCollectionOf(document interface{}) (collection *mongox.Collection, err error) {
|
||||
el := reflect.TypeOf(document).Elem()
|
||||
numField := el.NumField()
|
||||
databaseName := d.name
|
||||
|
||||
for i := 0; i < numField; i++ {
|
||||
field := el.Field(i)
|
||||
tag := field.Tag
|
||||
found, ok := tag.Lookup("collection")
|
||||
if !ok {
|
||||
collectionName, found := tag.Lookup("collection")
|
||||
if !found {
|
||||
continue
|
||||
}
|
||||
|
||||
return d.client.Database(d.dbname).Collection(found)
|
||||
return d.client.Database(databaseName).Collection(collectionName), nil
|
||||
}
|
||||
|
||||
panic(fmt.Errorf("document %v does not have a collection tag", document))
|
||||
}
|
||||
|
||||
func (d *Database) createSimpleLoad(target interface{}, composed *query.Query) (cursor *mongox.Cursor, err error) {
|
||||
|
||||
collection := d.GetCollectionOf(target)
|
||||
opts := options.Find()
|
||||
|
||||
opts.Sort = composed.Sorter()
|
||||
opts.Limit = composed.Limiter()
|
||||
opts.Skip = composed.Skipper()
|
||||
|
||||
return collection.Find(d.Context(), composed.M(), opts)
|
||||
}
|
||||
|
||||
func (d *Database) createAggregateLoad(target interface{}, composed *query.Query) (cursor *mongox.Cursor, err error) {
|
||||
|
||||
collection := d.GetCollectionOf(target)
|
||||
opts := options.Aggregate()
|
||||
|
||||
pipeline := primitive.A{}
|
||||
|
||||
if !composed.Empty() {
|
||||
pipeline = append(pipeline, primitive.M{"$match": composed.M()})
|
||||
}
|
||||
if composed.Sorter() != nil {
|
||||
pipeline = append(pipeline, primitive.M{"$sort": composed.Sorter()})
|
||||
}
|
||||
if composed.Skipper() != nil {
|
||||
pipeline = append(pipeline, primitive.M{"$skip": *composed.Skipper()})
|
||||
}
|
||||
if composed.Limiter() != nil {
|
||||
pipeline = append(pipeline, primitive.M{"$limit": *composed.Limiter()})
|
||||
}
|
||||
|
||||
el := reflect.ValueOf(target)
|
||||
elType := el.Type()
|
||||
if elType.Kind() == reflect.Ptr {
|
||||
elType = elType.Elem()
|
||||
}
|
||||
|
||||
numField := elType.NumField()
|
||||
preloads, _ := composed.Preloader()
|
||||
|
||||
for i := 0; i < numField; i++ {
|
||||
|
||||
field := elType.Field(i)
|
||||
tag := field.Tag
|
||||
|
||||
preloadTag, ok := tag.Lookup("preload")
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
jsonTag, _ := tag.Lookup("json")
|
||||
if jsonTag == "-" {
|
||||
panic(fmt.Errorf("preload private field is impossible"))
|
||||
}
|
||||
|
||||
jsonData := strings.SplitN(jsonTag, ",", 2)
|
||||
jsonName := field.Name
|
||||
if len(jsonData) > 0 {
|
||||
jsonName = strings.TrimSpace(jsonData[0])
|
||||
}
|
||||
|
||||
preloadData := strings.Split(preloadTag, ",")
|
||||
if len(preloadData) == 0 {
|
||||
continue
|
||||
}
|
||||
if len(preloadData) == 1 {
|
||||
panic(fmt.Errorf("there is no foreign field"))
|
||||
}
|
||||
|
||||
localField := strings.TrimSpace(preloadData[0])
|
||||
if len(localField) == 0 {
|
||||
localField = "_id"
|
||||
}
|
||||
|
||||
foreignField := strings.TrimSpace(preloadData[1])
|
||||
if len(foreignField) == 0 {
|
||||
panic(fmt.Errorf("there is no foreign field"))
|
||||
}
|
||||
|
||||
preloadLimiter := 100
|
||||
preloadReversed := false
|
||||
if len(preloadData) > 2 {
|
||||
stringLimit := strings.TrimSpace(preloadData[2])
|
||||
intLimit := preloadLimiter
|
||||
|
||||
preloadReversed = strings.HasPrefix(stringLimit, "-")
|
||||
if preloadReversed {
|
||||
stringLimit = stringLimit[1:]
|
||||
}
|
||||
|
||||
intLimit, err = strconv.Atoi(stringLimit)
|
||||
if err == nil {
|
||||
preloadLimiter = intLimit
|
||||
}
|
||||
}
|
||||
|
||||
for _, preload := range preloads {
|
||||
if preload != jsonName {
|
||||
continue
|
||||
}
|
||||
|
||||
field := elType.Field(i)
|
||||
fieldType := field.Type
|
||||
|
||||
isSlice := fieldType.Kind() == reflect.Slice
|
||||
if isSlice {
|
||||
fieldType = fieldType.Elem()
|
||||
}
|
||||
|
||||
isPtr := fieldType.Kind() != reflect.Ptr
|
||||
if isPtr {
|
||||
panic(fmt.Errorf("preload field should have ptr type"))
|
||||
}
|
||||
|
||||
lookupCollection := d.GetCollectionOf(reflect.Zero(fieldType).Interface())
|
||||
lookupVars := primitive.M{"selector": "$" + localField}
|
||||
lookupPipeline := primitive.A{
|
||||
primitive.M{"$match": primitive.M{"$expr": primitive.M{"$eq": primitive.A{"$" + foreignField, "$$selector"}}}},
|
||||
}
|
||||
|
||||
if preloadReversed {
|
||||
lookupPipeline = append(lookupPipeline, primitive.M{"$sort": primitive.M{"_id": -1}})
|
||||
}
|
||||
if isSlice && preloadLimiter > 0 {
|
||||
lookupPipeline = append(lookupPipeline, primitive.M{"$limit": preloadLimiter})
|
||||
} else if !isSlice {
|
||||
lookupPipeline = append(lookupPipeline, primitive.M{"$limit": 1})
|
||||
}
|
||||
|
||||
pipeline = append(pipeline, primitive.M{
|
||||
"$lookup": primitive.M{
|
||||
"from": lookupCollection.Name(),
|
||||
"let": lookupVars,
|
||||
"pipeline": lookupPipeline,
|
||||
"as": jsonName,
|
||||
},
|
||||
})
|
||||
|
||||
if isSlice {
|
||||
continue
|
||||
}
|
||||
|
||||
pipeline = append(pipeline, primitive.M{
|
||||
"$unwind": primitive.M{
|
||||
"preserveNullAndEmptyArrays": true,
|
||||
"path": "$" + jsonName,
|
||||
},
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
return collection.Aggregate(d.Context(), pipeline, opts)
|
||||
return nil, mongox.ErrNoCollection
|
||||
}
|
||||
|
||||
@@ -13,7 +13,6 @@ import (
|
||||
|
||||
// DeleteArray removes documents list from a database by their ids
|
||||
func (d *Database) DeleteArray(target interface{}, filters ...interface{}) (err error) {
|
||||
|
||||
targetV := reflect.ValueOf(target)
|
||||
targetT := targetV.Type()
|
||||
|
||||
@@ -36,42 +35,43 @@ func (d *Database) DeleteArray(target interface{}, filters ...interface{}) (err
|
||||
zeroElem := reflect.Zero(targetSliceElemT)
|
||||
targetLen := targetSliceV.Len()
|
||||
|
||||
collection, err := d.GetCollectionOf(zeroElem.Interface())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
composed, err := query.Compose(filters...)
|
||||
if err != nil {
|
||||
return
|
||||
return err
|
||||
}
|
||||
|
||||
collection := d.GetCollectionOf(zeroElem.Interface())
|
||||
ids := primitive.A{}
|
||||
ctx := query.WithContext(d.Context(), composed)
|
||||
if targetLen > 0 {
|
||||
var ids primitive.A
|
||||
for i := 0; i < targetLen; i++ {
|
||||
elem := targetSliceV.Index(i)
|
||||
elemID, err := base.GetID(elem.Interface())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for i := 0; i < targetLen; i++ {
|
||||
elem := targetSliceV.Index(i)
|
||||
ids = append(ids, base.GetID(elem.Interface()))
|
||||
}
|
||||
|
||||
defer func() {
|
||||
invokerr := composed.OnClose().Invoke(ctx, target)
|
||||
if err == nil {
|
||||
err = invokerr
|
||||
ids = append(ids, elemID)
|
||||
}
|
||||
|
||||
return
|
||||
}()
|
||||
|
||||
if len(ids) == 0 {
|
||||
return fmt.Errorf("can't delete zero elements")
|
||||
composed.And(primitive.M{"_id": primitive.M{"$in": ids}})
|
||||
}
|
||||
|
||||
composed.And(primitive.M{"_id": primitive.M{"$in": ids}})
|
||||
ctx := query.WithContext(d.Context(), composed)
|
||||
m := composed.M()
|
||||
opts := options.Delete()
|
||||
|
||||
result, err := collection.DeleteMany(ctx, composed.M(), options.Delete())
|
||||
defer func() { _ = composed.OnClose().Invoke(ctx, target) }()
|
||||
|
||||
result, err := collection.DeleteMany(ctx, m, opts)
|
||||
if err != nil {
|
||||
return
|
||||
return fmt.Errorf("while deleting array: %w", err)
|
||||
}
|
||||
if result.DeletedCount != int64(targetLen) {
|
||||
err = fmt.Errorf("can't verify delete result: removed count mismatch %d != %d", result.DeletedCount, targetLen)
|
||||
return fmt.Errorf("deleted count mismatch %d != %d", result.DeletedCount, targetLen)
|
||||
}
|
||||
|
||||
return
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -2,6 +2,7 @@ package database
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/mainnika/mongox-go-driver/v2/mongox/base/protection"
|
||||
|
||||
"github.com/modern-go/reflect2"
|
||||
"go.mongodb.org/mongo-driver/bson/primitive"
|
||||
@@ -13,51 +14,53 @@ import (
|
||||
|
||||
// DeleteOne removes a document from a database and then returns it into target
|
||||
func (d *Database) DeleteOne(target interface{}, filters ...interface{}) (err error) {
|
||||
|
||||
composed, err := query.Compose(filters...)
|
||||
if err != nil {
|
||||
return
|
||||
return err
|
||||
}
|
||||
|
||||
collection := d.GetCollectionOf(target)
|
||||
protected := base.GetProtection(target)
|
||||
ctx := query.WithContext(d.Context(), composed)
|
||||
|
||||
opts := options.FindOneAndDelete()
|
||||
opts.Sort = composed.Sorter()
|
||||
collection, err := d.GetCollectionOf(target)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if !reflect2.IsNil(target) {
|
||||
composed.And(primitive.M{"_id": base.GetID(target)})
|
||||
targetID, err := base.GetID(target)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
composed.And(primitive.M{"_id": targetID})
|
||||
}
|
||||
|
||||
protected := protection.Get(target)
|
||||
if protected != nil {
|
||||
query.Push(composed, protected)
|
||||
_, err := query.Push(composed, protected)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
protected.Restate()
|
||||
}
|
||||
|
||||
defer func() {
|
||||
invokerr := composed.OnClose().Invoke(ctx, target)
|
||||
if err == nil {
|
||||
err = invokerr
|
||||
}
|
||||
ctx := query.WithContext(d.Context(), composed)
|
||||
m := composed.M()
|
||||
opts := options.FindOneAndDelete()
|
||||
opts.Sort = composed.Sorter()
|
||||
|
||||
return
|
||||
}()
|
||||
defer func() { _ = composed.OnClose().Invoke(ctx, target) }()
|
||||
|
||||
result := collection.FindOneAndDelete(ctx, composed.M(), opts)
|
||||
result := collection.FindOneAndDelete(ctx, m, opts)
|
||||
if result.Err() != nil {
|
||||
return fmt.Errorf("can't create find one and delete result: %w", result.Err())
|
||||
}
|
||||
|
||||
err = result.Decode(target)
|
||||
if err != nil {
|
||||
return
|
||||
return fmt.Errorf("can't decode find one and delete result: %w", err)
|
||||
}
|
||||
|
||||
err = composed.OnDecode().Invoke(ctx, target)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
_ = composed.OnDecode().Invoke(ctx, target)
|
||||
|
||||
return
|
||||
return nil
|
||||
}
|
||||
|
||||
+13
-10
@@ -14,18 +14,21 @@ import (
|
||||
)
|
||||
|
||||
// IndexEnsure function ensures index in mongo collection of document
|
||||
// `index:""` -- https://docs.mongodb.com/manual/indexes/#create-an-index
|
||||
// `index:"-"` -- (descending)
|
||||
// `index:"-,+foo,+-bar"` -- https://docs.mongodb.com/manual/core/index-compound
|
||||
// `index:"-,+foo,+-bar,unique"` -- https://docs.mongodb.com/manual/core/index-unique
|
||||
// `index:"-,+foo,+-bar,unique,allowNull"` -- https://docs.mongodb.com/manual/core/index-partial
|
||||
// `index:"-,unique,allowNull,expireAfter=86400"` -- https://docs.mongodb.com/manual/core/index-ttl
|
||||
// `index:"-,unique,allowNull,expireAfter={{.Expire}}"` -- evaluate index as a golang template with `cfg` arguments
|
||||
//
|
||||
// `index:""` -- https://docs.mongodb.com/manual/indexes/#create-an-index
|
||||
// `index:"-"` -- (descending)
|
||||
// `index:"-,+foo,+-bar"` -- https://docs.mongodb.com/manual/core/index-compound
|
||||
// `index:"-,+foo,+-bar,unique"` -- https://docs.mongodb.com/manual/core/index-unique
|
||||
// `index:"-,+foo,+-bar,unique,allowNull"` -- https://docs.mongodb.com/manual/core/index-partial
|
||||
// `index:"-,unique,allowNull,expireAfter=86400"` -- https://docs.mongodb.com/manual/core/index-ttl
|
||||
// `index:"-,unique,allowNull,expireAfter={{.Expire}}"` -- evaluate index as a golang template with `cfg` arguments
|
||||
func (d *Database) IndexEnsure(cfg interface{}, document interface{}) (err error) {
|
||||
|
||||
el := reflect.ValueOf(document).Elem().Type()
|
||||
numField := el.NumField()
|
||||
documents := d.GetCollectionOf(document)
|
||||
collection, err := d.GetCollectionOf(document)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for i := 0; i < numField; i++ {
|
||||
|
||||
@@ -126,7 +129,7 @@ func (d *Database) IndexEnsure(cfg interface{}, document interface{}) (err error
|
||||
}
|
||||
}
|
||||
|
||||
_, err = documents.Indexes().CreateOne(d.Context(), mongo.IndexModel{Keys: index, Options: opts})
|
||||
_, err = collection.Indexes().CreateOne(d.Context(), mongo.IndexModel{Keys: index, Options: opts})
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package database_test
|
||||
|
||||
import (
|
||||
"github.com/stretchr/testify/require"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
@@ -142,7 +143,10 @@ func TestDatabase_Ensure(t *testing.T) {
|
||||
err = db.IndexEnsure(tt.settings, tt.doc)
|
||||
assert.NoError(t, err)
|
||||
|
||||
indexes, _ := db.GetCollectionOf(tt.doc).Indexes().List(db.Context())
|
||||
collection, err := db.GetCollectionOf(tt.doc)
|
||||
require.NoError(t, err)
|
||||
|
||||
indexes, _ := collection.Indexes().List(db.Context())
|
||||
index := new(map[string]interface{})
|
||||
|
||||
indexes.Next(db.Context()) // skip _id_
|
||||
|
||||
@@ -4,14 +4,12 @@ import (
|
||||
"fmt"
|
||||
"reflect"
|
||||
|
||||
"github.com/mainnika/mongox-go-driver/v2/mongox"
|
||||
"github.com/mainnika/mongox-go-driver/v2/mongox/base"
|
||||
"github.com/mainnika/mongox-go-driver/v2/mongox/query"
|
||||
)
|
||||
|
||||
// LoadArray loads an array of documents from the database by query
|
||||
func (d *Database) LoadArray(target interface{}, filters ...interface{}) (err error) {
|
||||
|
||||
targetV := reflect.ValueOf(target)
|
||||
targetT := targetV.Type()
|
||||
|
||||
@@ -31,61 +29,36 @@ func (d *Database) LoadArray(target interface{}, filters ...interface{}) (err er
|
||||
panic(fmt.Errorf("target slice should contain ptrs"))
|
||||
}
|
||||
|
||||
zeroElem := reflect.Zero(targetSliceElemT)
|
||||
|
||||
composed, err := query.Compose(filters...)
|
||||
if err != nil {
|
||||
return
|
||||
return err
|
||||
}
|
||||
|
||||
zeroElem := reflect.Zero(targetSliceElemT)
|
||||
_, hasPreloader := composed.Preloader()
|
||||
ctx := query.WithContext(d.Context(), composed)
|
||||
|
||||
var result *mongox.Cursor
|
||||
var i int
|
||||
defer func() { _ = composed.OnClose().Invoke(ctx, target) }()
|
||||
|
||||
defer func() {
|
||||
|
||||
if result != nil {
|
||||
closerr := result.Close(ctx)
|
||||
if err == nil {
|
||||
err = closerr
|
||||
}
|
||||
}
|
||||
|
||||
invokerr := composed.OnClose().Invoke(ctx, target)
|
||||
if err == nil {
|
||||
err = invokerr
|
||||
}
|
||||
|
||||
return
|
||||
}()
|
||||
|
||||
if hasPreloader {
|
||||
result, err = d.createAggregateLoad(zeroElem.Interface(), composed)
|
||||
} else {
|
||||
result, err = d.createSimpleLoad(zeroElem.Interface(), composed)
|
||||
}
|
||||
cur, err := d.createCursor(zeroElem.Interface(), composed)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("can't create find result: %w", err)
|
||||
return
|
||||
return fmt.Errorf("can't create find result: %w", err)
|
||||
}
|
||||
|
||||
for i = 0; result.Next(ctx); i++ {
|
||||
defer func() { _ = cur.Close(ctx) }()
|
||||
|
||||
var i int
|
||||
for i = 0; cur.Next(ctx); i++ {
|
||||
var elem interface{}
|
||||
|
||||
if i == targetSliceV.Len() {
|
||||
value := reflect.New(targetSliceElemT.Elem())
|
||||
elem = value.Interface()
|
||||
|
||||
err = composed.OnCreate().Invoke(ctx, elem)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
_ = composed.OnCreate().Invoke(ctx, elem)
|
||||
|
||||
err = result.Decode(elem)
|
||||
err = cur.Decode(elem)
|
||||
if err != nil {
|
||||
return
|
||||
return err
|
||||
}
|
||||
|
||||
targetSliceV = reflect.Append(targetSliceV, value)
|
||||
@@ -93,26 +66,24 @@ func (d *Database) LoadArray(target interface{}, filters ...interface{}) (err er
|
||||
elem = targetSliceV.Index(i).Interface()
|
||||
|
||||
if created := base.Reset(elem); created {
|
||||
err = composed.OnCreate().Invoke(ctx, elem)
|
||||
}
|
||||
if err != nil {
|
||||
return
|
||||
_ = composed.OnCreate().Invoke(ctx, elem)
|
||||
}
|
||||
|
||||
err = result.Decode(elem)
|
||||
err = cur.Decode(elem)
|
||||
if err != nil {
|
||||
return
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
err = composed.OnDecode().Invoke(ctx, elem)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
_ = composed.OnDecode().Invoke(ctx, elem)
|
||||
}
|
||||
err = cur.Err()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
targetSliceV = targetSliceV.Slice(0, i)
|
||||
targetV.Elem().Set(targetSliceV)
|
||||
|
||||
return
|
||||
return nil
|
||||
}
|
||||
|
||||
+13
-44
@@ -1,8 +1,6 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/mainnika/mongox-go-driver/v2/mongox"
|
||||
"github.com/mainnika/mongox-go-driver/v2/mongox/base"
|
||||
"github.com/mainnika/mongox-go-driver/v2/mongox/query"
|
||||
@@ -10,68 +8,39 @@ import (
|
||||
|
||||
// LoadOne function loads a first single target document by a query
|
||||
func (d *Database) LoadOne(target interface{}, filters ...interface{}) (err error) {
|
||||
|
||||
composed, err := query.Compose(append(filters, query.Limit(1))...)
|
||||
if err != nil {
|
||||
return
|
||||
return err
|
||||
}
|
||||
|
||||
_, hasPreloader := composed.Preloader()
|
||||
ctx := query.WithContext(d.Context(), composed)
|
||||
|
||||
var result *mongox.Cursor
|
||||
defer func() { _ = composed.OnClose().Invoke(ctx, target) }()
|
||||
|
||||
defer func() {
|
||||
|
||||
if result != nil {
|
||||
closerr := result.Close(ctx)
|
||||
if err == nil {
|
||||
err = closerr
|
||||
}
|
||||
}
|
||||
|
||||
invokerr := composed.OnClose().Invoke(ctx, target)
|
||||
if err == nil {
|
||||
err = invokerr
|
||||
}
|
||||
|
||||
return
|
||||
}()
|
||||
|
||||
if hasPreloader {
|
||||
result, err = d.createAggregateLoad(target, composed)
|
||||
} else {
|
||||
result, err = d.createSimpleLoad(target, composed)
|
||||
}
|
||||
cur, err := d.createCursor(target, composed)
|
||||
if err != nil {
|
||||
return fmt.Errorf("can't create find result: %w", err)
|
||||
return err
|
||||
}
|
||||
defer func() { _ = cur.Close(ctx) }()
|
||||
|
||||
hasNext := result.Next(ctx)
|
||||
if result.Err() != nil {
|
||||
err = result.Err()
|
||||
return
|
||||
hasNext := cur.Next(ctx)
|
||||
if cur.Err() != nil {
|
||||
return cur.Err()
|
||||
}
|
||||
if !hasNext {
|
||||
return mongox.ErrNoDocuments
|
||||
}
|
||||
|
||||
if created := base.Reset(target); created {
|
||||
err = composed.OnCreate().Invoke(ctx, target)
|
||||
}
|
||||
if err != nil {
|
||||
return
|
||||
_ = composed.OnCreate().Invoke(ctx, target)
|
||||
}
|
||||
|
||||
err = result.Decode(target)
|
||||
err = cur.Decode(target)
|
||||
if err != nil {
|
||||
return
|
||||
return err
|
||||
}
|
||||
|
||||
err = composed.OnDecode().Invoke(ctx, target)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
_ = composed.OnDecode().Invoke(ctx, target)
|
||||
|
||||
return
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -1,36 +1,25 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/mainnika/mongox-go-driver/v2/mongox"
|
||||
"github.com/mainnika/mongox-go-driver/v2/mongox/query"
|
||||
)
|
||||
|
||||
// LoadStream function loads documents one by one into a target channel
|
||||
func (d *Database) LoadStream(target interface{}, filters ...interface{}) (loader mongox.StreamLoader, err error) {
|
||||
|
||||
composed, err := query.Compose(filters...)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
_, hasPreloader := composed.Preloader()
|
||||
ctx := query.WithContext(d.Context(), composed)
|
||||
|
||||
var cursor *mongox.Cursor
|
||||
|
||||
if hasPreloader {
|
||||
cursor, err = d.createAggregateLoad(target, composed)
|
||||
} else {
|
||||
cursor, err = d.createSimpleLoad(target, composed)
|
||||
}
|
||||
cur, err := d.createCursor(target, composed)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("can't create find result: %w", err)
|
||||
return
|
||||
return nil, err
|
||||
}
|
||||
|
||||
loader = &StreamLoader{cur: cursor, ctx: ctx, query: composed}
|
||||
loader = &StreamLoader{cur: cur, ctx: ctx, query: composed}
|
||||
|
||||
return
|
||||
return loader, nil
|
||||
}
|
||||
|
||||
+18
-18
@@ -1,6 +1,7 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"github.com/mainnika/mongox-go-driver/v2/mongox/base/protection"
|
||||
"go.mongodb.org/mongo-driver/bson/primitive"
|
||||
"go.mongodb.org/mongo-driver/mongo/options"
|
||||
|
||||
@@ -10,38 +11,37 @@ import (
|
||||
|
||||
// SaveOne saves a single source document to the database
|
||||
func (d *Database) SaveOne(source interface{}, filters ...interface{}) (err error) {
|
||||
collection, err := d.GetCollectionOf(source)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
composed, err := query.Compose(filters...)
|
||||
if err != nil {
|
||||
return
|
||||
return err
|
||||
}
|
||||
|
||||
collection := d.GetCollectionOf(source)
|
||||
id := base.GetID(source)
|
||||
protected := base.GetProtection(source)
|
||||
ctx := query.WithContext(d.Context(), composed)
|
||||
|
||||
id, err := base.GetID(source)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
composed.And(primitive.M{"_id": id})
|
||||
|
||||
opts := options.FindOneAndReplace()
|
||||
opts.SetUpsert(true)
|
||||
opts.SetReturnDocument(options.After)
|
||||
|
||||
protected := protection.Get(source)
|
||||
if protected != nil {
|
||||
query.Push(composed, protected)
|
||||
protected.Restate()
|
||||
}
|
||||
|
||||
defer func() {
|
||||
invokerr := composed.OnClose().Invoke(ctx, source)
|
||||
if err == nil {
|
||||
err = invokerr
|
||||
}
|
||||
ctx := query.WithContext(d.Context(), composed)
|
||||
m := composed.M()
|
||||
opts := options.FindOneAndReplace()
|
||||
opts.SetUpsert(true)
|
||||
opts.SetReturnDocument(options.After)
|
||||
|
||||
return
|
||||
}()
|
||||
defer func() { _ = composed.OnClose().Invoke(ctx, source) }()
|
||||
|
||||
result := collection.FindOneAndReplace(ctx, composed.M(), source, opts)
|
||||
result := collection.FindOneAndReplace(ctx, m, source, opts)
|
||||
if result.Err() != nil {
|
||||
return result.Err()
|
||||
}
|
||||
|
||||
@@ -17,7 +17,6 @@ type StreamLoader struct {
|
||||
|
||||
// DecodeNextMsg decodes the next document to an interface or returns an error
|
||||
func (l *StreamLoader) DecodeNextMsg(i interface{}) (err error) {
|
||||
|
||||
err = l.Next()
|
||||
if err != nil {
|
||||
return
|
||||
@@ -33,41 +32,35 @@ func (l *StreamLoader) DecodeNextMsg(i interface{}) (err error) {
|
||||
|
||||
// DecodeMsg decodes the current cursor document into an interface
|
||||
func (l *StreamLoader) DecodeMsg(i interface{}) (err error) {
|
||||
|
||||
if created := base.Reset(i); created {
|
||||
err = l.query.OnDecode().Invoke(l.ctx, i)
|
||||
_ = l.query.OnDecode().Invoke(l.ctx, i)
|
||||
}
|
||||
if err != nil {
|
||||
return
|
||||
return err
|
||||
}
|
||||
|
||||
err = l.cur.Decode(i)
|
||||
if err != nil {
|
||||
return
|
||||
return err
|
||||
}
|
||||
|
||||
err = l.query.OnDecode().Invoke(l.ctx, i)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
_ = l.query.OnDecode().Invoke(l.ctx, i)
|
||||
|
||||
return
|
||||
return nil
|
||||
}
|
||||
|
||||
// Next loads next documents but doesn't perform decoding
|
||||
func (l *StreamLoader) Next() (err error) {
|
||||
|
||||
hasNext := l.cur.Next(l.ctx)
|
||||
err = l.cur.Err()
|
||||
|
||||
if err != nil {
|
||||
return
|
||||
return err
|
||||
}
|
||||
if !hasNext {
|
||||
err = mongox.ErrNoDocuments
|
||||
return mongox.ErrNoDocuments
|
||||
}
|
||||
|
||||
return
|
||||
return nil
|
||||
}
|
||||
|
||||
// Cursor returns the underlying cursor
|
||||
@@ -77,24 +70,21 @@ func (l *StreamLoader) Cursor() (cursor *mongox.Cursor) {
|
||||
|
||||
// Close stream loader and the underlying cursor
|
||||
func (l *StreamLoader) Close() (err error) {
|
||||
defer func() { _ = l.query.OnClose().Invoke(l.ctx, nil) }()
|
||||
|
||||
closerr := l.cur.Close(l.ctx)
|
||||
invokerr := l.query.OnClose().Invoke(l.ctx, nil)
|
||||
|
||||
if closerr != nil {
|
||||
err = closerr
|
||||
return
|
||||
err = l.cur.Close(l.ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if invokerr != nil {
|
||||
err = invokerr
|
||||
return
|
||||
}
|
||||
|
||||
return
|
||||
return nil
|
||||
}
|
||||
|
||||
// Err returns the last error
|
||||
func (l *StreamLoader) Err() (err error) {
|
||||
return l.cur.Err()
|
||||
}
|
||||
|
||||
func (l *StreamLoader) Context() (ctx context.Context) {
|
||||
return l.ctx
|
||||
}
|
||||
|
||||
@@ -1,72 +1,64 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"github.com/mainnika/mongox-go-driver/v2/mongox/base/protection"
|
||||
"github.com/modern-go/reflect2"
|
||||
"go.mongodb.org/mongo-driver/bson/primitive"
|
||||
"go.mongodb.org/mongo-driver/mongo/options"
|
||||
|
||||
"github.com/mainnika/mongox-go-driver/v2/mongox/base"
|
||||
"github.com/mainnika/mongox-go-driver/v2/mongox/query"
|
||||
)
|
||||
|
||||
// UpdateOne updates a single document in the database and loads it into target
|
||||
func (d *Database) UpdateOne(target interface{}, filters ...interface{}) (err error) {
|
||||
|
||||
composed, err := query.Compose(filters...)
|
||||
if err != nil {
|
||||
return
|
||||
return err
|
||||
}
|
||||
|
||||
updaterDoc, err := composed.Updater()
|
||||
update, err := composed.Updater()
|
||||
if err != nil {
|
||||
return
|
||||
return err
|
||||
}
|
||||
|
||||
collection := d.GetCollectionOf(target)
|
||||
protected := base.GetProtection(target)
|
||||
ctx := query.WithContext(d.Context(), composed)
|
||||
|
||||
opts := options.FindOneAndUpdate()
|
||||
opts.SetReturnDocument(options.After)
|
||||
|
||||
protected := protection.Get(target)
|
||||
if protected != nil {
|
||||
if !protected.X.IsZero() {
|
||||
query.Push(composed, protected)
|
||||
}
|
||||
|
||||
protected.Restate()
|
||||
|
||||
setCmd, _ := updaterDoc["$set"].(primitive.M)
|
||||
setCmd, _ := update["$set"].(primitive.M)
|
||||
if reflect2.IsNil(setCmd) {
|
||||
setCmd = primitive.M{}
|
||||
}
|
||||
protected.PutToDocument(setCmd)
|
||||
updaterDoc["$set"] = setCmd
|
||||
protected.Inject(setCmd)
|
||||
update["$set"] = setCmd
|
||||
}
|
||||
|
||||
defer func() {
|
||||
invokerr := composed.OnClose().Invoke(ctx, target)
|
||||
if err == nil {
|
||||
err = invokerr
|
||||
}
|
||||
collection, err := d.GetCollectionOf(target)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return
|
||||
}()
|
||||
ctx := query.WithContext(d.Context(), composed)
|
||||
m := composed.M()
|
||||
opts := options.FindOneAndUpdate()
|
||||
opts.SetReturnDocument(options.After)
|
||||
|
||||
result := collection.FindOneAndUpdate(ctx, composed.M(), updaterDoc, opts)
|
||||
defer func() { _ = composed.OnClose().Invoke(ctx, target) }()
|
||||
|
||||
result := collection.FindOneAndUpdate(ctx, m, update, opts)
|
||||
if result.Err() != nil {
|
||||
return result.Err()
|
||||
}
|
||||
|
||||
err = result.Decode(target)
|
||||
if err != nil {
|
||||
return
|
||||
return err
|
||||
}
|
||||
|
||||
err = composed.OnDecode().Invoke(ctx, target)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
_ = composed.OnDecode().Invoke(ctx, target)
|
||||
|
||||
return
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
package mongox
|
||||
|
||||
import (
|
||||
"errors"
|
||||
|
||||
"go.mongodb.org/mongo-driver/mongo"
|
||||
)
|
||||
|
||||
@@ -18,3 +20,9 @@ var (
|
||||
ErrWrongClient = mongo.ErrWrongClient
|
||||
ErrNoDocuments = mongo.ErrNoDocuments
|
||||
)
|
||||
|
||||
var (
|
||||
ErrMalformedBase = errors.New("source contains malformed document base")
|
||||
ErrUninitializedBase = errors.New("uninitialized document")
|
||||
ErrNoCollection = errors.New("no collection found")
|
||||
)
|
||||
|
||||
+3
-4
@@ -19,8 +19,7 @@ type Database interface {
|
||||
Client() (client *Client)
|
||||
Context() (context context.Context)
|
||||
Name() (name string)
|
||||
New(ctx context.Context) (db Database)
|
||||
GetCollectionOf(document interface{}) (collection *Collection)
|
||||
GetCollectionOf(document interface{}) (collection *Collection, err error)
|
||||
Count(target interface{}, filters ...interface{}) (count int64, err error)
|
||||
DeleteArray(target interface{}, filters ...interface{}) (err error)
|
||||
DeleteOne(target interface{}, filters ...interface{}) (err error)
|
||||
@@ -54,8 +53,8 @@ type StringBased interface {
|
||||
SetID(id string)
|
||||
}
|
||||
|
||||
// JSONBased is an interface for documents that have object type for the _id field
|
||||
type JSONBased interface {
|
||||
// DocBased is an interface for documents that have object type for the _id field
|
||||
type DocBased interface {
|
||||
GetID() (id primitive.D)
|
||||
SetID(id primitive.D)
|
||||
}
|
||||
|
||||
@@ -18,9 +18,9 @@ func (c Callbacks) Invoke(ctx context.Context, iter interface{}) (err error) {
|
||||
for _, cb := range c {
|
||||
err = cb(ctx, iter)
|
||||
if err != nil {
|
||||
return
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -0,0 +1,58 @@
|
||||
package query_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/mock"
|
||||
|
||||
"github.com/mainnika/mongox-go-driver/v2/mongox/query"
|
||||
)
|
||||
|
||||
func TestCallbacks_InvokeOk(t *testing.T) {
|
||||
|
||||
mocked := mock.Mock{}
|
||||
|
||||
var callbacks = query.Callbacks{
|
||||
func(ctx context.Context, iter interface{}) (err error) {
|
||||
return mocked.Called(ctx, iter).Error(0)
|
||||
},
|
||||
func(ctx context.Context, iter interface{}) (err error) {
|
||||
return mocked.Called(ctx, iter).Error(0)
|
||||
},
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
iter := int64(42)
|
||||
|
||||
mocked.On("func1", ctx, iter).Return(nil).Once()
|
||||
mocked.On("func2", ctx, iter).Return(nil).Once()
|
||||
|
||||
assert.NoError(t, callbacks.Invoke(ctx, iter))
|
||||
assert.True(t, mocked.AssertExpectations(t))
|
||||
}
|
||||
|
||||
func TestCallbacks_InvokeStopIfError(t *testing.T) {
|
||||
|
||||
mocked := mock.Mock{}
|
||||
|
||||
var callbacks = query.Callbacks{
|
||||
func(ctx context.Context, iter interface{}) (err error) {
|
||||
return mocked.Called(ctx, iter).Error(0)
|
||||
},
|
||||
func(ctx context.Context, iter interface{}) (err error) {
|
||||
t.FailNow()
|
||||
return
|
||||
},
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
iter := int(42)
|
||||
|
||||
mocked.On("func1", ctx, iter).Return(fmt.Errorf("wat"))
|
||||
|
||||
assert.EqualError(t, callbacks.Invoke(ctx, iter), "wat")
|
||||
assert.True(t, mocked.AssertExpectations(t))
|
||||
}
|
||||
+16
-28
@@ -9,13 +9,11 @@ import (
|
||||
"github.com/mainnika/mongox-go-driver/v2/mongox/base/protection"
|
||||
)
|
||||
|
||||
type applyFilterFunc = func(query *Query, filter interface{}) (ok bool)
|
||||
type applyFilterFunc func(query *Query, filter interface{}) (ok bool)
|
||||
|
||||
// Compose is a function to compose filters into a single query
|
||||
func Compose(filters ...interface{}) (query *Query, err error) {
|
||||
|
||||
query = &Query{}
|
||||
|
||||
for _, filter := range filters {
|
||||
ok, err := Push(query, filter)
|
||||
if err != nil {
|
||||
@@ -26,25 +24,25 @@ func Compose(filters ...interface{}) (query *Query, err error) {
|
||||
}
|
||||
}
|
||||
|
||||
return
|
||||
return query, nil
|
||||
}
|
||||
|
||||
// Push applies single filter to a query
|
||||
func Push(query *Query, filter interface{}) (ok bool, err error) {
|
||||
|
||||
ok = reflect2.IsNil(filter)
|
||||
if ok {
|
||||
return
|
||||
emptyFilter := reflect2.IsNil(filter)
|
||||
if emptyFilter {
|
||||
return true, nil
|
||||
}
|
||||
|
||||
validator, hasValidator := filter.(Validator)
|
||||
if hasValidator {
|
||||
err = validator.Validate()
|
||||
}
|
||||
if err != nil {
|
||||
return
|
||||
err := validator.Validate()
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("while validating filter %v, %w", filter, err)
|
||||
}
|
||||
}
|
||||
|
||||
ok = false // true if at least one filter was applied
|
||||
for _, applier := range []applyFilterFunc{
|
||||
applyBson,
|
||||
applyLimit,
|
||||
@@ -58,12 +56,11 @@ func Push(query *Query, filter interface{}) (ok bool, err error) {
|
||||
ok = applier(query, filter) || ok
|
||||
}
|
||||
|
||||
return
|
||||
return ok, nil
|
||||
}
|
||||
|
||||
// applyBson is a fallback for a custom primitive.M
|
||||
func applyBson(query *Query, filter interface{}) (ok bool) {
|
||||
|
||||
if filter, ok := filter.(primitive.M); ok {
|
||||
query.And(filter)
|
||||
return true
|
||||
@@ -74,7 +71,6 @@ func applyBson(query *Query, filter interface{}) (ok bool) {
|
||||
|
||||
// applyLimits extends query with a limiter
|
||||
func applyLimit(query *Query, filter interface{}) (ok bool) {
|
||||
|
||||
if filter, ok := filter.(Limiter); ok {
|
||||
query.limiter = filter
|
||||
return true
|
||||
@@ -85,7 +81,6 @@ func applyLimit(query *Query, filter interface{}) (ok bool) {
|
||||
|
||||
// applySort extends query with a sort rule
|
||||
func applySort(query *Query, filter interface{}) (ok bool) {
|
||||
|
||||
if filter, ok := filter.(Sorter); ok {
|
||||
query.sorter = filter
|
||||
return true
|
||||
@@ -96,7 +91,6 @@ func applySort(query *Query, filter interface{}) (ok bool) {
|
||||
|
||||
// applySkip extends query with a skip number
|
||||
func applySkip(query *Query, filter interface{}) (ok bool) {
|
||||
|
||||
if filter, ok := filter.(Skipper); ok {
|
||||
query.skipper = filter
|
||||
return true
|
||||
@@ -106,15 +100,12 @@ func applySkip(query *Query, filter interface{}) (ok bool) {
|
||||
}
|
||||
|
||||
func applyProtection(query *Query, filter interface{}) (ok bool) {
|
||||
|
||||
var keyDoc = primitive.M{}
|
||||
|
||||
keyDoc := primitive.M{}
|
||||
switch filter := filter.(type) {
|
||||
case protection.Key:
|
||||
filter.PutToDocument(keyDoc)
|
||||
filter.Inject(keyDoc)
|
||||
case *protection.Key:
|
||||
filter.PutToDocument(keyDoc)
|
||||
|
||||
filter.Inject(keyDoc)
|
||||
default:
|
||||
return false
|
||||
}
|
||||
@@ -125,7 +116,6 @@ func applyProtection(query *Query, filter interface{}) (ok bool) {
|
||||
}
|
||||
|
||||
func applyPreloader(query *Query, filter interface{}) (ok bool) {
|
||||
|
||||
if filter, ok := filter.(Preloader); ok {
|
||||
query.preloader = filter
|
||||
return true
|
||||
@@ -135,7 +125,6 @@ func applyPreloader(query *Query, filter interface{}) (ok bool) {
|
||||
}
|
||||
|
||||
func applyUpdater(query *Query, filter interface{}) (ok bool) {
|
||||
|
||||
if filter, ok := filter.(Updater); ok {
|
||||
query.updater = filter
|
||||
return true
|
||||
@@ -145,12 +134,11 @@ func applyUpdater(query *Query, filter interface{}) (ok bool) {
|
||||
}
|
||||
|
||||
func applyCallbacks(query *Query, filter interface{}) (ok bool) {
|
||||
|
||||
switch callback := filter.(type) {
|
||||
case OnDecode:
|
||||
query.ondecode = append(query.ondecode, Callback(callback))
|
||||
query.onDecode = append(query.onDecode, Callback(callback))
|
||||
case OnClose:
|
||||
query.onclose = append(query.onclose, Callback(callback))
|
||||
query.onClose = append(query.onClose, Callback(callback))
|
||||
default:
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -0,0 +1,128 @@
|
||||
package query_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"go.mongodb.org/mongo-driver/bson/primitive"
|
||||
|
||||
"github.com/mainnika/mongox-go-driver/v2/mongox/base/protection"
|
||||
"github.com/mainnika/mongox-go-driver/v2/mongox/query"
|
||||
)
|
||||
|
||||
func TestPushBSON(t *testing.T) {
|
||||
|
||||
q := &query.Query{}
|
||||
|
||||
ok, err := query.Push(q, primitive.M{"foo": "bar"})
|
||||
|
||||
assert.True(t, ok)
|
||||
assert.NoError(t, err)
|
||||
assert.NotEmpty(t, q.M())
|
||||
assert.Len(t, q.M()["$and"], 1)
|
||||
assert.Contains(t, q.M()["$and"], primitive.M{"foo": "bar"})
|
||||
|
||||
ok, err = query.Push(q, primitive.M{"bar": "foo"})
|
||||
|
||||
assert.True(t, ok)
|
||||
assert.NoError(t, err)
|
||||
assert.NotEmpty(t, q.M())
|
||||
assert.Len(t, q.M()["$and"], 2)
|
||||
assert.Contains(t, q.M()["$and"], primitive.M{"foo": "bar"})
|
||||
assert.Contains(t, q.M()["$and"], primitive.M{"bar": "foo"})
|
||||
}
|
||||
|
||||
func TestPushLimiter(t *testing.T) {
|
||||
|
||||
q := &query.Query{}
|
||||
lim := query.Limit(2)
|
||||
|
||||
ok, err := query.Push(q, lim)
|
||||
|
||||
assert.True(t, ok)
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, q.Limiter())
|
||||
assert.EqualValues(t, q.Limiter(), query.Limit(2).Limit())
|
||||
}
|
||||
|
||||
func TestPushSorter(t *testing.T) {
|
||||
|
||||
q := &query.Query{}
|
||||
sort := query.Sort{"foo": 1}
|
||||
|
||||
ok, err := query.Push(q, sort)
|
||||
|
||||
assert.True(t, ok)
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, q.Sorter())
|
||||
assert.EqualValues(t, q.Sorter(), primitive.M{"foo": 1})
|
||||
}
|
||||
|
||||
func TestPushSkipper(t *testing.T) {
|
||||
|
||||
q := &query.Query{}
|
||||
skip := query.Skip(66)
|
||||
|
||||
ok, err := query.Push(q, skip)
|
||||
|
||||
assert.True(t, ok)
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, q.Skipper())
|
||||
assert.EqualValues(t, q.Skipper(), query.Skip(66).Skip())
|
||||
}
|
||||
|
||||
func TestPushProtection(t *testing.T) {
|
||||
|
||||
t.Run("push protection key pointer", func(t *testing.T) {
|
||||
q := &query.Query{}
|
||||
protected := &protection.Key{V: 1, X: primitive.ObjectID{2}}
|
||||
|
||||
ok, err := query.Push(q, protected)
|
||||
|
||||
assert.True(t, ok)
|
||||
assert.NoError(t, err)
|
||||
assert.NotEmpty(t, q.M()["$and"])
|
||||
assert.Contains(t, q.M()["$and"], primitive.M{"_x": primitive.ObjectID{2}, "_v": int64(1)})
|
||||
})
|
||||
|
||||
t.Run("push protection key struct", func(t *testing.T) {
|
||||
q := &query.Query{}
|
||||
protected := protection.Key{V: 1, X: primitive.ObjectID{2}}
|
||||
|
||||
ok, err := query.Push(q, protected)
|
||||
|
||||
assert.True(t, ok)
|
||||
assert.NoError(t, err)
|
||||
assert.NotEmpty(t, q.M()["$and"])
|
||||
assert.Contains(t, q.M()["$and"], primitive.M{"_x": primitive.ObjectID{2}, "_v": int64(1)})
|
||||
})
|
||||
|
||||
t.Run("protection key is empty", func(t *testing.T) {
|
||||
q := &query.Query{}
|
||||
protected := &protection.Key{}
|
||||
|
||||
ok, err := query.Push(q, protected)
|
||||
|
||||
assert.True(t, ok)
|
||||
assert.NoError(t, err)
|
||||
assert.NotEmpty(t, q.M()["$and"])
|
||||
assert.Contains(t, q.M()["$and"], primitive.M{"_x": primitive.M{"$exists": false}, "_v": primitive.M{"$exists": false}})
|
||||
})
|
||||
}
|
||||
|
||||
func TestPushPreloader(t *testing.T) {
|
||||
|
||||
q := &query.Query{}
|
||||
preloader := query.Preload{"a", "b"}
|
||||
|
||||
ok, err := query.Push(q, preloader)
|
||||
|
||||
assert.True(t, ok)
|
||||
assert.NoError(t, err)
|
||||
|
||||
p, hasPreloader := q.Preloader()
|
||||
|
||||
assert.NotNil(t, p)
|
||||
assert.True(t, hasPreloader)
|
||||
assert.EqualValues(t, p, query.Preload{"a", "b"})
|
||||
}
|
||||
@@ -9,11 +9,14 @@ type ctxQueryKey struct{}
|
||||
// GetFromContext function extracts the request data from context
|
||||
func GetFromContext(ctx context.Context) (q *Query, ok bool) {
|
||||
q, ok = ctx.Value(ctxQueryKey{}).(*Query)
|
||||
return
|
||||
if !ok {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
return q, true
|
||||
}
|
||||
|
||||
// WithContext function creates the new context with request data
|
||||
func WithContext(ctx context.Context, q *Query) (withQuery context.Context) {
|
||||
withQuery = context.WithValue(ctx, ctxQueryKey{}, q)
|
||||
return
|
||||
return context.WithValue(ctx, ctxQueryKey{}, q)
|
||||
}
|
||||
|
||||
@@ -12,13 +12,12 @@ var _ Limiter = Limit(0)
|
||||
|
||||
// Limit returns a limit
|
||||
func (l Limit) Limit() (limit *int64) {
|
||||
|
||||
if l <= 0 {
|
||||
return
|
||||
return nil
|
||||
}
|
||||
|
||||
limit = new(int64)
|
||||
*limit = int64(l)
|
||||
|
||||
return
|
||||
return limit
|
||||
}
|
||||
|
||||
@@ -1,11 +1,11 @@
|
||||
package query
|
||||
|
||||
// Preloader is a filter to skip the result
|
||||
// Preloader is a filter to preload the result
|
||||
type Preloader interface {
|
||||
Preload() (preloads []string)
|
||||
}
|
||||
|
||||
// Preload is a simple implementation of the Skipper filter
|
||||
// Preload is a simple implementation of the Preloader filter
|
||||
type Preload []string
|
||||
|
||||
var _ Preloader = Preload{}
|
||||
|
||||
+22
-28
@@ -15,35 +15,31 @@ type Query struct {
|
||||
skipper Skipper
|
||||
preloader Preloader
|
||||
updater Updater
|
||||
ondecode Callbacks
|
||||
onclose Callbacks
|
||||
oncreate Callbacks
|
||||
onDecode Callbacks
|
||||
onClose Callbacks
|
||||
onCreate Callbacks
|
||||
}
|
||||
|
||||
// And function pushes the elem query to the $and array of the query
|
||||
func (q *Query) And(elem primitive.M) (query *Query) {
|
||||
|
||||
if q.m == nil {
|
||||
q.m = primitive.M{}
|
||||
}
|
||||
|
||||
queries, exists := q.m["$and"].(primitive.A)
|
||||
|
||||
if !exists {
|
||||
q.m["$and"] = primitive.A{elem}
|
||||
return q
|
||||
}
|
||||
|
||||
q.m["$and"] = append(queries, elem)
|
||||
|
||||
return q
|
||||
}
|
||||
|
||||
// Limiter returns limiter value or nil
|
||||
func (q *Query) Limiter() (limit *int64) {
|
||||
|
||||
if q.limiter == nil {
|
||||
return
|
||||
return nil
|
||||
}
|
||||
|
||||
return q.limiter.Limit()
|
||||
@@ -51,9 +47,8 @@ func (q *Query) Limiter() (limit *int64) {
|
||||
|
||||
// Sorter is a sort rule for a query
|
||||
func (q *Query) Sorter() (sort interface{}) {
|
||||
|
||||
if q.sorter == nil {
|
||||
return
|
||||
return nil
|
||||
}
|
||||
|
||||
return q.sorter.Sort()
|
||||
@@ -61,9 +56,8 @@ func (q *Query) Sorter() (sort interface{}) {
|
||||
|
||||
// Skipper is a skipper for a query
|
||||
func (q *Query) Skipper() (skip *int64) {
|
||||
|
||||
if q.skipper == nil {
|
||||
return
|
||||
return nil
|
||||
}
|
||||
|
||||
return q.skipper.Skip()
|
||||
@@ -71,62 +65,57 @@ func (q *Query) Skipper() (skip *int64) {
|
||||
|
||||
// Updater is an update command for a query
|
||||
func (q *Query) Updater() (update primitive.M, err error) {
|
||||
|
||||
if q.updater == nil {
|
||||
update = primitive.M{}
|
||||
return
|
||||
return primitive.M{}, nil
|
||||
}
|
||||
|
||||
update = q.updater.Update()
|
||||
|
||||
if reflect2.IsNil(update) {
|
||||
update = primitive.M{}
|
||||
return
|
||||
return primitive.M{}, nil
|
||||
}
|
||||
|
||||
buffer := bytebufferpool.Get()
|
||||
defer bytebufferpool.Put(buffer)
|
||||
|
||||
// convert update document to bson map values
|
||||
buffer.Reset()
|
||||
bsonBytes, err := bson.MarshalAppend(buffer.B, update)
|
||||
if err != nil {
|
||||
return
|
||||
return primitive.M{}, err
|
||||
}
|
||||
update = primitive.M{}
|
||||
update = primitive.M{} // reset update map and unmarshal bson bytes to it again
|
||||
err = bson.Unmarshal(bsonBytes, update)
|
||||
if err != nil {
|
||||
return
|
||||
return primitive.M{}, err
|
||||
}
|
||||
|
||||
return
|
||||
return update, nil
|
||||
}
|
||||
|
||||
// Preloader is a preloader list for a query
|
||||
func (q *Query) Preloader() (preloads []string, ok bool) {
|
||||
|
||||
if q.preloader == nil {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
preloads = q.preloader.Preload()
|
||||
ok = len(preloads) > 0
|
||||
|
||||
return
|
||||
return preloads, len(preloads) > 0
|
||||
}
|
||||
|
||||
// OnDecode callback is called after the mongo decode function
|
||||
func (q *Query) OnDecode() (callbacks Callbacks) {
|
||||
return q.ondecode
|
||||
return q.onDecode
|
||||
}
|
||||
|
||||
// OnClose callback is called after the mongox ends a loading procedure
|
||||
func (q *Query) OnClose() (callbacks Callbacks) {
|
||||
return q.onclose
|
||||
return q.onClose
|
||||
}
|
||||
|
||||
// OnCreate callback is called if the mongox creates a new document instance during loading
|
||||
func (q *Query) OnCreate() (callbacks Callbacks) {
|
||||
return q.onclose
|
||||
return q.onClose
|
||||
}
|
||||
|
||||
// Empty checks the query for any content
|
||||
@@ -138,3 +127,8 @@ func (q *Query) Empty() (isEmpty bool) {
|
||||
func (q *Query) M() (m primitive.M) {
|
||||
return q.m
|
||||
}
|
||||
|
||||
// New creates a new query
|
||||
func New() (query *Query) {
|
||||
return &Query{}
|
||||
}
|
||||
|
||||
@@ -12,13 +12,12 @@ var _ Skipper = Skip(0)
|
||||
|
||||
// Skip returns a skip number
|
||||
func (l Skip) Skip() (skip *int64) {
|
||||
|
||||
if l <= 0 {
|
||||
return
|
||||
return nil
|
||||
}
|
||||
|
||||
skip = new(int64)
|
||||
*skip = int64(l)
|
||||
|
||||
return
|
||||
return skip
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user