diff --git a/mongox/database/loadarray.go b/mongox/database/loadarray.go index d9b5139..79072a5 100644 --- a/mongox/database/loadarray.go +++ b/mongox/database/loadarray.go @@ -34,6 +34,7 @@ func (d *Database) LoadArray(target interface{}, filters ...interface{}) (err er composed := query.Compose(filters...) zeroElem := reflect.Zero(targetSliceElemT) hasPreloader, _ := composed.Preloader() + ctx := query.WithContext(d.Context(), composed) var result *mongox.Cursor var i int @@ -48,9 +49,9 @@ func (d *Database) LoadArray(target interface{}, filters ...interface{}) (err er return } - defer composed.OnClose().Invoke(d.Context(), target) + defer composed.OnClose().Invoke(ctx, target) - for i = 0; result.Next(d.Context()); { + for i = 0; result.Next(ctx); { var elem interface{} @@ -67,13 +68,13 @@ func (d *Database) LoadArray(target interface{}, filters ...interface{}) (err er err = result.Decode(elem) } if err != nil { - _ = result.Close(d.Context()) + _ = result.Close(ctx) return } - err = composed.OnDecode().Invoke(d.Context(), elem) + err = composed.OnDecode().Invoke(ctx, elem) if err != nil { - _ = result.Close(d.Context()) + _ = result.Close(ctx) return } @@ -83,5 +84,5 @@ func (d *Database) LoadArray(target interface{}, filters ...interface{}) (err er targetSliceV = targetSliceV.Slice(0, i) targetV.Elem().Set(targetSliceV) - return result.Close(d.Context()) + return result.Close(ctx) } diff --git a/mongox/database/loadone.go b/mongox/database/loadone.go index 8bc1db5..3ae7027 100644 --- a/mongox/database/loadone.go +++ b/mongox/database/loadone.go @@ -13,6 +13,7 @@ func (d *Database) LoadOne(target interface{}, filters ...interface{}) (err erro composed := query.Compose(append(filters, query.Limit(1))...) hasPreloader, _ := composed.Preloader() + ctx := query.WithContext(d.Context(), composed) var result *mongox.Cursor @@ -25,9 +26,9 @@ func (d *Database) LoadOne(target interface{}, filters ...interface{}) (err erro return fmt.Errorf("can't create find result: %w", err) } - defer composed.OnClose().Invoke(d.Context(), target) + defer composed.OnClose().Invoke(ctx, target) - hasNext := result.Next(d.Context()) + hasNext := result.Next(ctx) if result.Err() != nil { return err } @@ -42,7 +43,7 @@ func (d *Database) LoadOne(target interface{}, filters ...interface{}) (err erro return } - err = composed.OnDecode().Invoke(d.Context(), target) + err = composed.OnDecode().Invoke(ctx, target) if err != nil { return } diff --git a/mongox/database/loadstream.go b/mongox/database/loadstream.go index ee55585..7a53145 100644 --- a/mongox/database/loadstream.go +++ b/mongox/database/loadstream.go @@ -14,6 +14,7 @@ func (d *Database) LoadStream(target interface{}, filters ...interface{}) (loade composed := query.Compose(filters...) hasPreloader, _ := composed.Preloader() + ctx := query.WithContext(d.Context(), composed) if hasPreloader { cursor, err = d.createAggregateLoad(target, composed) @@ -25,7 +26,7 @@ func (d *Database) LoadStream(target interface{}, filters ...interface{}) (loade return } - loader = &StreamLoader{cur: cursor, ctx: d.Context(), target: target, query: composed} + loader = &StreamLoader{cur: cursor, ctx: ctx, target: target, query: composed} return } diff --git a/mongox/query/context.go b/mongox/query/context.go new file mode 100644 index 0000000..07e5b8a --- /dev/null +++ b/mongox/query/context.go @@ -0,0 +1,19 @@ +package query + +import ( + "context" +) + +type ctxQueryKey struct{} + +// GetFromContext function extracts the request data from context +func GetFromContext(ctx context.Context) (q *Query, ok bool) { + q, ok = ctx.Value(ctxQueryKey{}).(*Query) + return +} + +// WithContext function creates the new context with request data +func WithContext(ctx context.Context, q *Query) (withQuery context.Context) { + withQuery = context.WithValue(ctx, ctxQueryKey{}, q) + return +}