parent
7c3e50e783
commit
23029ae710
@ -0,0 +1,44 @@ |
||||
package docbased |
||||
|
||||
import ( |
||||
"github.com/modern-go/reflect2" |
||||
"go.mongodb.org/mongo-driver/bson/primitive" |
||||
|
||||
"github.com/mainnika/mongox-go-driver/v2/mongox" |
||||
) |
||||
|
||||
var _ mongox.DocBased = (*Primary)(nil) |
||||
|
||||
// Primary is a structure with object as an _id field
|
||||
type Primary struct { |
||||
ID primitive.D `bson:"_id" json:"_id"` |
||||
} |
||||
|
||||
// GetID returns an _id
|
||||
func (p *Primary) GetID() (id primitive.D) { |
||||
return p.ID |
||||
} |
||||
|
||||
// SetID sets an _id
|
||||
func (p *Primary) SetID(id primitive.D) { |
||||
p.ID = id |
||||
} |
||||
|
||||
// New creates a new Primary structure with a defined _id
|
||||
func New(e primitive.E, ee ...primitive.E) Primary { |
||||
id := primitive.D{e} |
||||
if len(ee) > 0 { |
||||
id = append(id, ee...) |
||||
} |
||||
|
||||
return Primary{ID: id} |
||||
} |
||||
|
||||
func GetID(source mongox.DocBased) (id primitive.D, err error) { |
||||
id = source.GetID() |
||||
if !reflect2.IsNil(id) { |
||||
return id, nil |
||||
} |
||||
|
||||
return nil, mongox.ErrUninitializedBase |
||||
} |
@ -1,40 +0,0 @@ |
||||
package base |
||||
|
||||
import ( |
||||
"reflect" |
||||
|
||||
"github.com/mainnika/mongox-go-driver/v2/mongox/base/protection" |
||||
) |
||||
|
||||
// GetProtection function finds protection field in the source document otherwise returns nil
|
||||
func GetProtection(source interface{}) (key *protection.Key) { |
||||
|
||||
v := reflect.ValueOf(source) |
||||
if v.Kind() != reflect.Ptr || v.IsNil() { |
||||
return |
||||
} |
||||
|
||||
el := v.Elem() |
||||
numField := el.NumField() |
||||
|
||||
for i := 0; i < numField; i++ { |
||||
field := el.Field(i) |
||||
if !field.CanInterface() { |
||||
continue |
||||
} |
||||
|
||||
switch field.Interface().(type) { |
||||
case *protection.Key: |
||||
key = field.Interface().(*protection.Key) |
||||
case protection.Key: |
||||
ptr := field.Addr() |
||||
key = ptr.Interface().(*protection.Key) |
||||
default: |
||||
continue |
||||
} |
||||
|
||||
return |
||||
} |
||||
|
||||
return |
||||
} |
@ -0,0 +1,16 @@ |
||||
package ifacebased |
||||
|
||||
import ( |
||||
"github.com/mainnika/mongox-go-driver/v2/mongox" |
||||
"github.com/modern-go/reflect2" |
||||
) |
||||
|
||||
// GetID returns an _id from the source document
|
||||
func GetID(source mongox.InterfaceBased) (id interface{}, err error) { |
||||
id = source.GetID() |
||||
if !reflect2.IsNil(id) { |
||||
return id, nil |
||||
} |
||||
|
||||
return nil, mongox.ErrUninitializedBase |
||||
} |
@ -1,24 +0,0 @@ |
||||
package jsonbased |
||||
|
||||
import ( |
||||
"go.mongodb.org/mongo-driver/bson/primitive" |
||||
|
||||
"github.com/mainnika/mongox-go-driver/v2/mongox" |
||||
) |
||||
|
||||
var _ mongox.JSONBased = (*Primary)(nil) |
||||
|
||||
// Primary is a structure with object as an _id field
|
||||
type Primary struct { |
||||
ID primitive.D `bson:"_id" json:"_id"` |
||||
} |
||||
|
||||
// GetID returns an _id
|
||||
func (p *Primary) GetID() (id primitive.D) { |
||||
return p.ID |
||||
} |
||||
|
||||
// SetID sets an _id
|
||||
func (p *Primary) SetID(id primitive.D) { |
||||
p.ID = id |
||||
} |
@ -0,0 +1,3 @@ |
||||
package protection_test |
||||
|
||||
// TODO:
|
@ -0,0 +1,23 @@ |
||||
package database |
||||
|
||||
import ( |
||||
"context" |
||||
) |
||||
|
||||
type ctxDatabaseKey struct{} |
||||
|
||||
// GetFromContext function extracts the request data from context
|
||||
func GetFromContext(ctx context.Context) (q *Database, ok bool) { |
||||
q, ok = ctx.Value(ctxDatabaseKey{}).(*Database) |
||||
if !ok { |
||||
return nil, false |
||||
} |
||||
|
||||
return q, true |
||||
} |
||||
|
||||
// WithContext creates the new context with a database attached
|
||||
func WithContext(ctx context.Context, q *Database) (withQuery context.Context) { |
||||
db := NewDatabase(ctx, q.Client(), q.Name()) |
||||
return context.WithValue(ctx, ctxDatabaseKey{}, db) |
||||
} |
@ -0,0 +1,188 @@ |
||||
package database |
||||
|
||||
import ( |
||||
"fmt" |
||||
"reflect" |
||||
"strconv" |
||||
"strings" |
||||
|
||||
"go.mongodb.org/mongo-driver/bson/primitive" |
||||
"go.mongodb.org/mongo-driver/mongo/options" |
||||
|
||||
"github.com/mainnika/mongox-go-driver/v2/mongox" |
||||
"github.com/mainnika/mongox-go-driver/v2/mongox/query" |
||||
) |
||||
|
||||
func (d *Database) createCursor(target interface{}, composed *query.Query) (cursor *mongox.Cursor, err error) { |
||||
_, hasPreloader := composed.Preloader() |
||||
if hasPreloader { |
||||
return d.createAggregateCursor(target, composed) |
||||
} |
||||
|
||||
return d.createSimpleCursor(target, composed) |
||||
} |
||||
|
||||
func (d *Database) createSimpleCursor(target interface{}, composed *query.Query) (cursor *mongox.Cursor, err error) { |
||||
collection, err := d.GetCollectionOf(target) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
opts := options.Find() |
||||
opts.Sort = composed.Sorter() |
||||
opts.Limit = composed.Limiter() |
||||
opts.Skip = composed.Skipper() |
||||
|
||||
ctx := d.Context() |
||||
m := composed.M() |
||||
|
||||
return collection.Find(ctx, m, opts) |
||||
} |
||||
|
||||
func (d *Database) createAggregateCursor(target interface{}, composed *query.Query) (cursor *mongox.Cursor, err error) { |
||||
collection, err := d.GetCollectionOf(target) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
pipeline := primitive.A{} |
||||
if !composed.Empty() { |
||||
pipeline = append(pipeline, primitive.M{"$match": composed.M()}) |
||||
} |
||||
if composed.Sorter() != nil { |
||||
pipeline = append(pipeline, primitive.M{"$sort": composed.Sorter()}) |
||||
} |
||||
if composed.Skipper() != nil { |
||||
pipeline = append(pipeline, primitive.M{"$skip": *composed.Skipper()}) |
||||
} |
||||
if composed.Limiter() != nil { |
||||
pipeline = append(pipeline, primitive.M{"$limit": *composed.Limiter()}) |
||||
} |
||||
|
||||
el := reflect.ValueOf(target) |
||||
elType := el.Type() |
||||
if elType.Kind() == reflect.Ptr { |
||||
elType = elType.Elem() |
||||
} |
||||
|
||||
numField := elType.NumField() |
||||
preloads, _ := composed.Preloader() |
||||
for i := 0; i < numField; i++ { |
||||
field := elType.Field(i) |
||||
tag := field.Tag |
||||
|
||||
preloadTag, ok := tag.Lookup("preload") |
||||
if !ok { |
||||
continue |
||||
} |
||||
jsonTag, _ := tag.Lookup("json") |
||||
if jsonTag == "-" { |
||||
return nil, fmt.Errorf("%w: private field is not preloadable", mongox.ErrMalformedBase) |
||||
} |
||||
|
||||
jsonData := strings.SplitN(jsonTag, ",", 2) |
||||
jsonName := field.Name |
||||
if len(jsonData) > 0 { |
||||
jsonName = strings.TrimSpace(jsonData[0]) |
||||
} |
||||
|
||||
preloadData := strings.Split(preloadTag, ",") |
||||
if len(preloadData) == 0 { |
||||
continue |
||||
} |
||||
if len(preloadData) == 1 { |
||||
return nil, fmt.Errorf("%w: foreign field is not specified", mongox.ErrMalformedBase) |
||||
} |
||||
|
||||
foreignField := strings.TrimSpace(preloadData[1]) |
||||
if len(foreignField) == 0 { |
||||
return nil, fmt.Errorf("%w: foreign field is empty", mongox.ErrMalformedBase) |
||||
} |
||||
localField := strings.TrimSpace(preloadData[0]) |
||||
if len(localField) == 0 { |
||||
localField = "_id" |
||||
} |
||||
|
||||
preloadLimiter := 100 |
||||
preloadReversed := false |
||||
if len(preloadData) > 2 { |
||||
stringLimit := strings.TrimSpace(preloadData[2]) |
||||
intLimit := preloadLimiter |
||||
|
||||
preloadReversed = strings.HasPrefix(stringLimit, "-") |
||||
if preloadReversed { |
||||
stringLimit = stringLimit[1:] |
||||
} |
||||
|
||||
intLimit, err = strconv.Atoi(stringLimit) |
||||
if err == nil { |
||||
preloadLimiter = intLimit |
||||
} else { |
||||
return nil, fmt.Errorf("%w: preload limit should be an integer", mongox.ErrMalformedBase) |
||||
} |
||||
} |
||||
|
||||
for _, preload := range preloads { |
||||
if preload != jsonName { |
||||
continue |
||||
} |
||||
|
||||
field := elType.Field(i) |
||||
fieldType := field.Type |
||||
|
||||
isSlice := fieldType.Kind() == reflect.Slice |
||||
if isSlice { |
||||
fieldType = fieldType.Elem() |
||||
} |
||||
|
||||
isPtr := fieldType.Kind() != reflect.Ptr |
||||
if isPtr { |
||||
return nil, fmt.Errorf("%w: preload field should have ptr type", mongox.ErrMalformedBase) |
||||
} |
||||
|
||||
lookupCollection, err := d.GetCollectionOf(reflect.Zero(fieldType).Interface()) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
lookupVars := primitive.M{"selector": "$" + localField} |
||||
lookupPipeline := primitive.A{ |
||||
primitive.M{"$match": primitive.M{"$expr": primitive.M{"$eq": primitive.A{"$" + foreignField, "$$selector"}}}}, |
||||
} |
||||
|
||||
if preloadReversed { |
||||
lookupPipeline = append(lookupPipeline, primitive.M{"$sort": primitive.M{"_id": -1}}) |
||||
} |
||||
if isSlice && preloadLimiter > 0 { |
||||
lookupPipeline = append(lookupPipeline, primitive.M{"$limit": preloadLimiter}) |
||||
} else if !isSlice { |
||||
lookupPipeline = append(lookupPipeline, primitive.M{"$limit": 1}) |
||||
} |
||||
|
||||
pipeline = append(pipeline, primitive.M{ |
||||
"$lookup": primitive.M{ |
||||
"from": lookupCollection.Name(), |
||||
"let": lookupVars, |
||||
"pipeline": lookupPipeline, |
||||
"as": jsonName, |
||||
}, |
||||
}) |
||||
|
||||
if isSlice { |
||||
continue |
||||
} |
||||
|
||||
pipeline = append(pipeline, primitive.M{ |
||||
"$unwind": primitive.M{ |
||||
"preserveNullAndEmptyArrays": true, |
||||
"path": "$" + jsonName, |
||||
}, |
||||
}) |
||||
} |
||||
} |
||||
|
||||
ctx := d.Context() |
||||
opts := options.Aggregate() |
||||
|
||||
return collection.Aggregate(ctx, pipeline, opts) |
||||
} |
@ -1,36 +1,25 @@ |
||||
package database |
||||
|
||||
import ( |
||||
"fmt" |
||||
|
||||
"github.com/mainnika/mongox-go-driver/v2/mongox" |
||||
"github.com/mainnika/mongox-go-driver/v2/mongox/query" |
||||
) |
||||
|
||||
// LoadStream function loads documents one by one into a target channel
|
||||
func (d *Database) LoadStream(target interface{}, filters ...interface{}) (loader mongox.StreamLoader, err error) { |
||||
|
||||
composed, err := query.Compose(filters...) |
||||
if err != nil { |
||||
return |
||||
} |
||||
|
||||
_, hasPreloader := composed.Preloader() |
||||
ctx := query.WithContext(d.Context(), composed) |
||||
|
||||
var cursor *mongox.Cursor |
||||
|
||||
if hasPreloader { |
||||
cursor, err = d.createAggregateLoad(target, composed) |
||||
} else { |
||||
cursor, err = d.createSimpleLoad(target, composed) |
||||
} |
||||
cur, err := d.createCursor(target, composed) |
||||
if err != nil { |
||||
err = fmt.Errorf("can't create find result: %w", err) |
||||
return |
||||
return nil, err |
||||
} |
||||
|
||||
loader = &StreamLoader{cur: cursor, ctx: ctx, query: composed} |
||||
loader = &StreamLoader{cur: cur, ctx: ctx, query: composed} |
||||
|
||||
return |
||||
return loader, nil |
||||
} |
||||
|
@ -1,72 +1,64 @@ |
||||
package database |
||||
|
||||
import ( |
||||
"github.com/mainnika/mongox-go-driver/v2/mongox/base/protection" |
||||
"github.com/modern-go/reflect2" |
||||
"go.mongodb.org/mongo-driver/bson/primitive" |
||||
"go.mongodb.org/mongo-driver/mongo/options" |
||||
|
||||
"github.com/mainnika/mongox-go-driver/v2/mongox/base" |
||||
"github.com/mainnika/mongox-go-driver/v2/mongox/query" |
||||
) |
||||
|
||||
// UpdateOne updates a single document in the database and loads it into target
|
||||
func (d *Database) UpdateOne(target interface{}, filters ...interface{}) (err error) { |
||||
|
||||
composed, err := query.Compose(filters...) |
||||
if err != nil { |
||||
return |
||||
return err |
||||
} |
||||
|
||||
updaterDoc, err := composed.Updater() |
||||
update, err := composed.Updater() |
||||
if err != nil { |
||||
return |
||||
return err |
||||
} |
||||
|
||||
collection := d.GetCollectionOf(target) |
||||
protected := base.GetProtection(target) |
||||
ctx := query.WithContext(d.Context(), composed) |
||||
|
||||
opts := options.FindOneAndUpdate() |
||||
opts.SetReturnDocument(options.After) |
||||
|
||||
protected := protection.Get(target) |
||||
if protected != nil { |
||||
if !protected.X.IsZero() { |
||||
query.Push(composed, protected) |
||||
} |
||||
|
||||
protected.Restate() |
||||
|
||||
setCmd, _ := updaterDoc["$set"].(primitive.M) |
||||
setCmd, _ := update["$set"].(primitive.M) |
||||
if reflect2.IsNil(setCmd) { |
||||
setCmd = primitive.M{} |
||||
} |
||||
protected.PutToDocument(setCmd) |
||||
updaterDoc["$set"] = setCmd |
||||
protected.Inject(setCmd) |
||||
update["$set"] = setCmd |
||||
} |
||||
|
||||
defer func() { |
||||
invokerr := composed.OnClose().Invoke(ctx, target) |
||||
if err == nil { |
||||
err = invokerr |
||||
} |
||||
collection, err := d.GetCollectionOf(target) |
||||
if err != nil { |
||||
return err |
||||
} |
||||
|
||||
ctx := query.WithContext(d.Context(), composed) |
||||
m := composed.M() |
||||
opts := options.FindOneAndUpdate() |
||||
opts.SetReturnDocument(options.After) |
||||
|
||||
return |
||||
}() |
||||
defer func() { _ = composed.OnClose().Invoke(ctx, target) }() |
||||
|
||||
result := collection.FindOneAndUpdate(ctx, composed.M(), updaterDoc, opts) |
||||
result := collection.FindOneAndUpdate(ctx, m, update, opts) |
||||
if result.Err() != nil { |
||||
return result.Err() |
||||
} |
||||
|
||||
err = result.Decode(target) |
||||
if err != nil { |
||||
return |
||||
return err |
||||
} |
||||
|
||||
err = composed.OnDecode().Invoke(ctx, target) |
||||
if err != nil { |
||||
return |
||||
} |
||||
_ = composed.OnDecode().Invoke(ctx, target) |
||||
|
||||
return |
||||
return nil |
||||
} |
||||
|
@ -0,0 +1,58 @@ |
||||
package query_test |
||||
|
||||
import ( |
||||
"context" |
||||
"fmt" |
||||
"testing" |
||||
|
||||
"github.com/stretchr/testify/assert" |
||||
"github.com/stretchr/testify/mock" |
||||
|
||||
"github.com/mainnika/mongox-go-driver/v2/mongox/query" |
||||
) |
||||
|
||||
func TestCallbacks_InvokeOk(t *testing.T) { |
||||
|
||||
mocked := mock.Mock{} |
||||
|
||||
var callbacks = query.Callbacks{ |
||||
func(ctx context.Context, iter interface{}) (err error) { |
||||
return mocked.Called(ctx, iter).Error(0) |
||||
}, |
||||
func(ctx context.Context, iter interface{}) (err error) { |
||||
return mocked.Called(ctx, iter).Error(0) |
||||
}, |
||||
} |
||||
|
||||
ctx := context.Background() |
||||
iter := int64(42) |
||||
|
||||
mocked.On("func1", ctx, iter).Return(nil).Once() |
||||
mocked.On("func2", ctx, iter).Return(nil).Once() |
||||
|
||||
assert.NoError(t, callbacks.Invoke(ctx, iter)) |
||||
assert.True(t, mocked.AssertExpectations(t)) |
||||
} |
||||
|
||||
func TestCallbacks_InvokeStopIfError(t *testing.T) { |
||||
|
||||
mocked := mock.Mock{} |
||||
|
||||
var callbacks = query.Callbacks{ |
||||
func(ctx context.Context, iter interface{}) (err error) { |
||||
return mocked.Called(ctx, iter).Error(0) |
||||
}, |
||||
func(ctx context.Context, iter interface{}) (err error) { |
||||
t.FailNow() |
||||
return |
||||
}, |
||||
} |
||||
|
||||
ctx := context.Background() |
||||
iter := int(42) |
||||
|
||||
mocked.On("func1", ctx, iter).Return(fmt.Errorf("wat")) |
||||
|
||||
assert.EqualError(t, callbacks.Invoke(ctx, iter), "wat") |
||||
assert.True(t, mocked.AssertExpectations(t)) |
||||
} |
@ -0,0 +1,128 @@ |
||||
package query_test |
||||
|
||||
import ( |
||||
"testing" |
||||
|
||||
"github.com/stretchr/testify/assert" |
||||
"go.mongodb.org/mongo-driver/bson/primitive" |
||||
|
||||
"github.com/mainnika/mongox-go-driver/v2/mongox/base/protection" |
||||
"github.com/mainnika/mongox-go-driver/v2/mongox/query" |
||||
) |
||||
|
||||
func TestPushBSON(t *testing.T) { |
||||
|
||||
q := &query.Query{} |
||||
|
||||
ok, err := query.Push(q, primitive.M{"foo": "bar"}) |
||||
|
||||
assert.True(t, ok) |
||||
assert.NoError(t, err) |
||||
assert.NotEmpty(t, q.M()) |
||||
assert.Len(t, q.M()["$and"], 1) |
||||
assert.Contains(t, q.M()["$and"], primitive.M{"foo": "bar"}) |
||||
|
||||
ok, err = query.Push(q, primitive.M{"bar": "foo"}) |
||||
|
||||
assert.True(t, ok) |
||||
assert.NoError(t, err) |
||||
assert.NotEmpty(t, q.M()) |
||||
assert.Len(t, q.M()["$and"], 2) |
||||
assert.Contains(t, q.M()["$and"], primitive.M{"foo": "bar"}) |
||||
assert.Contains(t, q.M()["$and"], primitive.M{"bar": "foo"}) |
||||
} |
||||
|
||||
func TestPushLimiter(t *testing.T) { |
||||
|
||||
q := &query.Query{} |
||||
lim := query.Limit(2) |
||||
|
||||
ok, err := query.Push(q, lim) |
||||
|
||||
assert.True(t, ok) |
||||
assert.NoError(t, err) |
||||
assert.NotNil(t, q.Limiter()) |
||||
assert.EqualValues(t, q.Limiter(), query.Limit(2).Limit()) |
||||
} |
||||
|
||||
func TestPushSorter(t *testing.T) { |
||||
|
||||
q := &query.Query{} |
||||
sort := query.Sort{"foo": 1} |
||||
|
||||
ok, err := query.Push(q, sort) |
||||
|
||||
assert.True(t, ok) |
||||
assert.NoError(t, err) |
||||
assert.NotNil(t, q.Sorter()) |
||||
assert.EqualValues(t, q.Sorter(), primitive.M{"foo": 1}) |
||||
} |
||||
|
||||
func TestPushSkipper(t *testing.T) { |
||||
|
||||
q := &query.Query{} |
||||
skip := query.Skip(66) |
||||
|
||||
ok, err := query.Push(q, skip) |
||||
|
||||
assert.True(t, ok) |
||||
assert.NoError(t, err) |
||||
assert.NotNil(t, q.Skipper()) |
||||
assert.EqualValues(t, q.Skipper(), query.Skip(66).Skip()) |
||||
} |
||||
|
||||
func TestPushProtection(t *testing.T) { |
||||
|
||||
t.Run("push protection key pointer", func(t *testing.T) { |
||||
q := &query.Query{} |
||||
protected := &protection.Key{V: 1, X: primitive.ObjectID{2}} |
||||
|
||||
ok, err := query.Push(q, protected) |
||||
|
||||
assert.True(t, ok) |
||||
assert.NoError(t, err) |
||||
assert.NotEmpty(t, q.M()["$and"]) |
||||
assert.Contains(t, q.M()["$and"], primitive.M{"_x": primitive.ObjectID{2}, "_v": int64(1)}) |
||||
}) |
||||
|
||||
t.Run("push protection key struct", func(t *testing.T) { |
||||
q := &query.Query{} |
||||
protected := protection.Key{V: 1, X: primitive.ObjectID{2}} |
||||
|
||||
ok, err := query.Push(q, protected) |
||||
|
||||
assert.True(t, ok) |
||||
assert.NoError(t, err) |
||||
assert.NotEmpty(t, q.M()["$and"]) |
||||
assert.Contains(t, q.M()["$and"], primitive.M{"_x": primitive.ObjectID{2}, "_v": int64(1)}) |
||||
}) |
||||
|
||||
t.Run("protection key is empty", func(t *testing.T) { |
||||
q := &query.Query{} |
||||
protected := &protection.Key{} |
||||
|
||||
ok, err := query.Push(q, protected) |
||||
|
||||
assert.True(t, ok) |
||||
assert.NoError(t, err) |
||||
assert.NotEmpty(t, q.M()["$and"]) |
||||
assert.Contains(t, q.M()["$and"], primitive.M{"_x": primitive.M{"$exists": false}, "_v": primitive.M{"$exists": false}}) |
||||
}) |
||||
} |
||||
|
||||
func TestPushPreloader(t *testing.T) { |
||||
|
||||
q := &query.Query{} |
||||
preloader := query.Preload{"a", "b"} |
||||
|
||||
ok, err := query.Push(q, preloader) |
||||
|
||||
assert.True(t, ok) |
||||
assert.NoError(t, err) |
||||
|
||||
p, hasPreloader := q.Preloader() |
||||
|
||||
assert.NotNil(t, p) |
||||
assert.True(t, hasPreloader) |
||||
assert.EqualValues(t, p, query.Preload{"a", "b"}) |
||||
} |
Loading…
Reference in new issue