Prepare for v3

This commit is contained in:
2023-06-10 00:44:20 +02:00
parent 7c3e50e783
commit 23029ae710
42 changed files with 880 additions and 719 deletions
+2 -2
View File
@@ -18,9 +18,9 @@ func (c Callbacks) Invoke(ctx context.Context, iter interface{}) (err error) {
for _, cb := range c {
err = cb(ctx, iter)
if err != nil {
return
return err
}
}
return
return nil
}
+58
View File
@@ -0,0 +1,58 @@
package query_test
import (
"context"
"fmt"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"github.com/mainnika/mongox-go-driver/v2/mongox/query"
)
func TestCallbacks_InvokeOk(t *testing.T) {
mocked := mock.Mock{}
var callbacks = query.Callbacks{
func(ctx context.Context, iter interface{}) (err error) {
return mocked.Called(ctx, iter).Error(0)
},
func(ctx context.Context, iter interface{}) (err error) {
return mocked.Called(ctx, iter).Error(0)
},
}
ctx := context.Background()
iter := int64(42)
mocked.On("func1", ctx, iter).Return(nil).Once()
mocked.On("func2", ctx, iter).Return(nil).Once()
assert.NoError(t, callbacks.Invoke(ctx, iter))
assert.True(t, mocked.AssertExpectations(t))
}
func TestCallbacks_InvokeStopIfError(t *testing.T) {
mocked := mock.Mock{}
var callbacks = query.Callbacks{
func(ctx context.Context, iter interface{}) (err error) {
return mocked.Called(ctx, iter).Error(0)
},
func(ctx context.Context, iter interface{}) (err error) {
t.FailNow()
return
},
}
ctx := context.Background()
iter := int(42)
mocked.On("func1", ctx, iter).Return(fmt.Errorf("wat"))
assert.EqualError(t, callbacks.Invoke(ctx, iter), "wat")
assert.True(t, mocked.AssertExpectations(t))
}
+16 -28
View File
@@ -9,13 +9,11 @@ import (
"github.com/mainnika/mongox-go-driver/v2/mongox/base/protection"
)
type applyFilterFunc = func(query *Query, filter interface{}) (ok bool)
type applyFilterFunc func(query *Query, filter interface{}) (ok bool)
// Compose is a function to compose filters into a single query
func Compose(filters ...interface{}) (query *Query, err error) {
query = &Query{}
for _, filter := range filters {
ok, err := Push(query, filter)
if err != nil {
@@ -26,25 +24,25 @@ func Compose(filters ...interface{}) (query *Query, err error) {
}
}
return
return query, nil
}
// Push applies single filter to a query
func Push(query *Query, filter interface{}) (ok bool, err error) {
ok = reflect2.IsNil(filter)
if ok {
return
emptyFilter := reflect2.IsNil(filter)
if emptyFilter {
return true, nil
}
validator, hasValidator := filter.(Validator)
if hasValidator {
err = validator.Validate()
}
if err != nil {
return
err := validator.Validate()
if err != nil {
return false, fmt.Errorf("while validating filter %v, %w", filter, err)
}
}
ok = false // true if at least one filter was applied
for _, applier := range []applyFilterFunc{
applyBson,
applyLimit,
@@ -58,12 +56,11 @@ func Push(query *Query, filter interface{}) (ok bool, err error) {
ok = applier(query, filter) || ok
}
return
return ok, nil
}
// applyBson is a fallback for a custom primitive.M
func applyBson(query *Query, filter interface{}) (ok bool) {
if filter, ok := filter.(primitive.M); ok {
query.And(filter)
return true
@@ -74,7 +71,6 @@ func applyBson(query *Query, filter interface{}) (ok bool) {
// applyLimits extends query with a limiter
func applyLimit(query *Query, filter interface{}) (ok bool) {
if filter, ok := filter.(Limiter); ok {
query.limiter = filter
return true
@@ -85,7 +81,6 @@ func applyLimit(query *Query, filter interface{}) (ok bool) {
// applySort extends query with a sort rule
func applySort(query *Query, filter interface{}) (ok bool) {
if filter, ok := filter.(Sorter); ok {
query.sorter = filter
return true
@@ -96,7 +91,6 @@ func applySort(query *Query, filter interface{}) (ok bool) {
// applySkip extends query with a skip number
func applySkip(query *Query, filter interface{}) (ok bool) {
if filter, ok := filter.(Skipper); ok {
query.skipper = filter
return true
@@ -106,15 +100,12 @@ func applySkip(query *Query, filter interface{}) (ok bool) {
}
func applyProtection(query *Query, filter interface{}) (ok bool) {
var keyDoc = primitive.M{}
keyDoc := primitive.M{}
switch filter := filter.(type) {
case protection.Key:
filter.PutToDocument(keyDoc)
filter.Inject(keyDoc)
case *protection.Key:
filter.PutToDocument(keyDoc)
filter.Inject(keyDoc)
default:
return false
}
@@ -125,7 +116,6 @@ func applyProtection(query *Query, filter interface{}) (ok bool) {
}
func applyPreloader(query *Query, filter interface{}) (ok bool) {
if filter, ok := filter.(Preloader); ok {
query.preloader = filter
return true
@@ -135,7 +125,6 @@ func applyPreloader(query *Query, filter interface{}) (ok bool) {
}
func applyUpdater(query *Query, filter interface{}) (ok bool) {
if filter, ok := filter.(Updater); ok {
query.updater = filter
return true
@@ -145,12 +134,11 @@ func applyUpdater(query *Query, filter interface{}) (ok bool) {
}
func applyCallbacks(query *Query, filter interface{}) (ok bool) {
switch callback := filter.(type) {
case OnDecode:
query.ondecode = append(query.ondecode, Callback(callback))
query.onDecode = append(query.onDecode, Callback(callback))
case OnClose:
query.onclose = append(query.onclose, Callback(callback))
query.onClose = append(query.onClose, Callback(callback))
default:
return false
}
+128
View File
@@ -0,0 +1,128 @@
package query_test
import (
"testing"
"github.com/stretchr/testify/assert"
"go.mongodb.org/mongo-driver/bson/primitive"
"github.com/mainnika/mongox-go-driver/v2/mongox/base/protection"
"github.com/mainnika/mongox-go-driver/v2/mongox/query"
)
func TestPushBSON(t *testing.T) {
q := &query.Query{}
ok, err := query.Push(q, primitive.M{"foo": "bar"})
assert.True(t, ok)
assert.NoError(t, err)
assert.NotEmpty(t, q.M())
assert.Len(t, q.M()["$and"], 1)
assert.Contains(t, q.M()["$and"], primitive.M{"foo": "bar"})
ok, err = query.Push(q, primitive.M{"bar": "foo"})
assert.True(t, ok)
assert.NoError(t, err)
assert.NotEmpty(t, q.M())
assert.Len(t, q.M()["$and"], 2)
assert.Contains(t, q.M()["$and"], primitive.M{"foo": "bar"})
assert.Contains(t, q.M()["$and"], primitive.M{"bar": "foo"})
}
func TestPushLimiter(t *testing.T) {
q := &query.Query{}
lim := query.Limit(2)
ok, err := query.Push(q, lim)
assert.True(t, ok)
assert.NoError(t, err)
assert.NotNil(t, q.Limiter())
assert.EqualValues(t, q.Limiter(), query.Limit(2).Limit())
}
func TestPushSorter(t *testing.T) {
q := &query.Query{}
sort := query.Sort{"foo": 1}
ok, err := query.Push(q, sort)
assert.True(t, ok)
assert.NoError(t, err)
assert.NotNil(t, q.Sorter())
assert.EqualValues(t, q.Sorter(), primitive.M{"foo": 1})
}
func TestPushSkipper(t *testing.T) {
q := &query.Query{}
skip := query.Skip(66)
ok, err := query.Push(q, skip)
assert.True(t, ok)
assert.NoError(t, err)
assert.NotNil(t, q.Skipper())
assert.EqualValues(t, q.Skipper(), query.Skip(66).Skip())
}
func TestPushProtection(t *testing.T) {
t.Run("push protection key pointer", func(t *testing.T) {
q := &query.Query{}
protected := &protection.Key{V: 1, X: primitive.ObjectID{2}}
ok, err := query.Push(q, protected)
assert.True(t, ok)
assert.NoError(t, err)
assert.NotEmpty(t, q.M()["$and"])
assert.Contains(t, q.M()["$and"], primitive.M{"_x": primitive.ObjectID{2}, "_v": int64(1)})
})
t.Run("push protection key struct", func(t *testing.T) {
q := &query.Query{}
protected := protection.Key{V: 1, X: primitive.ObjectID{2}}
ok, err := query.Push(q, protected)
assert.True(t, ok)
assert.NoError(t, err)
assert.NotEmpty(t, q.M()["$and"])
assert.Contains(t, q.M()["$and"], primitive.M{"_x": primitive.ObjectID{2}, "_v": int64(1)})
})
t.Run("protection key is empty", func(t *testing.T) {
q := &query.Query{}
protected := &protection.Key{}
ok, err := query.Push(q, protected)
assert.True(t, ok)
assert.NoError(t, err)
assert.NotEmpty(t, q.M()["$and"])
assert.Contains(t, q.M()["$and"], primitive.M{"_x": primitive.M{"$exists": false}, "_v": primitive.M{"$exists": false}})
})
}
func TestPushPreloader(t *testing.T) {
q := &query.Query{}
preloader := query.Preload{"a", "b"}
ok, err := query.Push(q, preloader)
assert.True(t, ok)
assert.NoError(t, err)
p, hasPreloader := q.Preloader()
assert.NotNil(t, p)
assert.True(t, hasPreloader)
assert.EqualValues(t, p, query.Preload{"a", "b"})
}
+6 -3
View File
@@ -9,11 +9,14 @@ type ctxQueryKey struct{}
// GetFromContext function extracts the request data from context
func GetFromContext(ctx context.Context) (q *Query, ok bool) {
q, ok = ctx.Value(ctxQueryKey{}).(*Query)
return
if !ok {
return nil, false
}
return q, true
}
// WithContext function creates the new context with request data
func WithContext(ctx context.Context, q *Query) (withQuery context.Context) {
withQuery = context.WithValue(ctx, ctxQueryKey{}, q)
return
return context.WithValue(ctx, ctxQueryKey{}, q)
}
+2 -3
View File
@@ -12,13 +12,12 @@ var _ Limiter = Limit(0)
// Limit returns a limit
func (l Limit) Limit() (limit *int64) {
if l <= 0 {
return
return nil
}
limit = new(int64)
*limit = int64(l)
return
return limit
}
+2 -2
View File
@@ -1,11 +1,11 @@
package query
// Preloader is a filter to skip the result
// Preloader is a filter to preload the result
type Preloader interface {
Preload() (preloads []string)
}
// Preload is a simple implementation of the Skipper filter
// Preload is a simple implementation of the Preloader filter
type Preload []string
var _ Preloader = Preload{}
+22 -28
View File
@@ -15,35 +15,31 @@ type Query struct {
skipper Skipper
preloader Preloader
updater Updater
ondecode Callbacks
onclose Callbacks
oncreate Callbacks
onDecode Callbacks
onClose Callbacks
onCreate Callbacks
}
// And function pushes the elem query to the $and array of the query
func (q *Query) And(elem primitive.M) (query *Query) {
if q.m == nil {
q.m = primitive.M{}
}
queries, exists := q.m["$and"].(primitive.A)
if !exists {
q.m["$and"] = primitive.A{elem}
return q
}
q.m["$and"] = append(queries, elem)
return q
}
// Limiter returns limiter value or nil
func (q *Query) Limiter() (limit *int64) {
if q.limiter == nil {
return
return nil
}
return q.limiter.Limit()
@@ -51,9 +47,8 @@ func (q *Query) Limiter() (limit *int64) {
// Sorter is a sort rule for a query
func (q *Query) Sorter() (sort interface{}) {
if q.sorter == nil {
return
return nil
}
return q.sorter.Sort()
@@ -61,9 +56,8 @@ func (q *Query) Sorter() (sort interface{}) {
// Skipper is a skipper for a query
func (q *Query) Skipper() (skip *int64) {
if q.skipper == nil {
return
return nil
}
return q.skipper.Skip()
@@ -71,62 +65,57 @@ func (q *Query) Skipper() (skip *int64) {
// Updater is an update command for a query
func (q *Query) Updater() (update primitive.M, err error) {
if q.updater == nil {
update = primitive.M{}
return
return primitive.M{}, nil
}
update = q.updater.Update()
if reflect2.IsNil(update) {
update = primitive.M{}
return
return primitive.M{}, nil
}
buffer := bytebufferpool.Get()
defer bytebufferpool.Put(buffer)
// convert update document to bson map values
buffer.Reset()
bsonBytes, err := bson.MarshalAppend(buffer.B, update)
if err != nil {
return
return primitive.M{}, err
}
update = primitive.M{}
update = primitive.M{} // reset update map and unmarshal bson bytes to it again
err = bson.Unmarshal(bsonBytes, update)
if err != nil {
return
return primitive.M{}, err
}
return
return update, nil
}
// Preloader is a preloader list for a query
func (q *Query) Preloader() (preloads []string, ok bool) {
if q.preloader == nil {
return nil, false
}
preloads = q.preloader.Preload()
ok = len(preloads) > 0
return
return preloads, len(preloads) > 0
}
// OnDecode callback is called after the mongo decode function
func (q *Query) OnDecode() (callbacks Callbacks) {
return q.ondecode
return q.onDecode
}
// OnClose callback is called after the mongox ends a loading procedure
func (q *Query) OnClose() (callbacks Callbacks) {
return q.onclose
return q.onClose
}
// OnCreate callback is called if the mongox creates a new document instance during loading
func (q *Query) OnCreate() (callbacks Callbacks) {
return q.onclose
return q.onClose
}
// Empty checks the query for any content
@@ -138,3 +127,8 @@ func (q *Query) Empty() (isEmpty bool) {
func (q *Query) M() (m primitive.M) {
return q.m
}
// New creates a new query
func New() (query *Query) {
return &Query{}
}
+2 -3
View File
@@ -12,13 +12,12 @@ var _ Skipper = Skip(0)
// Skip returns a skip number
func (l Skip) Skip() (skip *int64) {
if l <= 0 {
return
return nil
}
skip = new(int64)
*skip = int64(l)
return
return skip
}