diff --git a/datastore/integration_test.go b/datastore/integration_test.go index 3c94a63a92a1..318f2a4d0cdc 100644 --- a/datastore/integration_test.go +++ b/datastore/integration_test.go @@ -32,6 +32,7 @@ import ( "cloud.google.com/go/internal/testutil" "cloud.google.com/go/internal/uid" "cloud.google.com/go/rpcreplay" + "github.com/google/go-cmp/cmp/cmpopts" "google.golang.org/api/iterator" "google.golang.org/api/option" pb "google.golang.org/genproto/googleapis/datastore/v1" @@ -1240,6 +1241,98 @@ func TestIntegration_AggregationQueries(t *testing.T) { } +func TestIntegration_RunAggregationQueryWithOptions(t *testing.T) { + ctx := context.Background() + client := newTestClient(ctx, t) + defer client.Close() + + _, _, now, parent, cleanup := createTestEntities(ctx, t, client, "RunAggregationQueryWithOptions", 3) + defer cleanup() + + aggQuery := NewQuery("SQChild").Ancestor(parent).Filter("T=", now).NewAggregationQuery(). + WithSum("I", "i_sum").WithAvg("I", "i_avg").WithCount("count") + wantAggResult := map[string]interface{}{ + "i_sum": &pb.Value{ValueType: &pb.Value_IntegerValue{IntegerValue: 6}}, + "i_avg": &pb.Value{ValueType: &pb.Value_DoubleValue{DoubleValue: 2}}, + "count": &pb.Value{ValueType: &pb.Value_IntegerValue{IntegerValue: 3}}, + } + + testCases := []struct { + desc string + wantFailure bool + wantErrMsg string + wantRes AggregationWithOptionsResult + opts []RunOption + }{ + { + desc: "No options", + wantRes: AggregationWithOptionsResult{ + Result: wantAggResult, + }, + }, + { + desc: "ExplainOptions.Analyze is false", + wantRes: AggregationWithOptionsResult{ + ExplainMetrics: &ExplainMetrics{ + PlanSummary: &PlanSummary{ + IndexesUsed: []*map[string]interface{}{ + { + "properties": "(T ASC, I ASC, __name__ ASC)", + "query_scope": "Includes ancestors", + }, + }, + }, + }, + }, + opts: []RunOption{ExplainOptions{}}, + }, + { + desc: "ExplainOptions.Analyze is true", + wantRes: AggregationWithOptionsResult{ + Result: wantAggResult, + ExplainMetrics: &ExplainMetrics{ + PlanSummary: &PlanSummary{ + IndexesUsed: []*map[string]interface{}{ + { + "properties": "(T ASC, I ASC, __name__ ASC)", + "query_scope": "Includes ancestors", + }, + }, + }, + ExecutionStats: &ExecutionStats{ + ReadOperations: 1, + ResultsReturned: 1, + DebugStats: &map[string]interface{}{ + "documents_scanned": "0", + "index_entries_scanned": "3", + }, + }, + }, + }, + opts: []RunOption{ExplainOptions{Analyze: true}}, + }, + } + + for _, testcase := range testCases { + testutil.Retry(t, 10, time.Second, func(r *testutil.R) { + gotRes, gotErr := client.RunAggregationQueryWithOptions(ctx, aggQuery, testcase.opts...) + if gotErr != nil { + r.Errorf("err: got %v, want: nil", gotErr) + } + + if gotErr == nil && !testutil.Equal(gotRes.Result, testcase.wantRes.Result, + cmpopts.IgnoreFields(ExplainMetrics{})) { + r.Errorf("%q: Mismatch in aggregation result got: %v, want: %v", testcase.desc, gotRes, testcase.wantRes) + return + } + + if err := cmpExplainMetrics(gotRes.ExplainMetrics, testcase.wantRes.ExplainMetrics); err != nil { + r.Errorf("%q: Mismatch in ExplainMetrics %+v", testcase.desc, err) + } + }) + } +} + type ckey struct{} func TestIntegration_LargeQuery(t *testing.T) { @@ -1556,6 +1649,189 @@ func TestIntegration_GetAllWithFieldMismatch(t *testing.T) { } } +func createTestEntities(ctx context.Context, t *testing.T, client *Client, partialNameKey string, count int) ([]*Key, []SQChild, int64, *Key, func()) { + parent := NameKey("SQParent", keyPrefix+partialNameKey+suffix, nil) + now := timeNow.Truncate(time.Millisecond).Unix() + + entities := []SQChild{} + for i := 0; i < count; i++ { + entities = append(entities, SQChild{I: i + 1, T: now, U: now, V: 1.5, W: "str"}) + } + + keys := make([]*Key, len(entities)) + for i := range keys { + keys[i] = IncompleteKey("SQChild", parent) + } + + // Create entities + keys, err := client.PutMulti(ctx, keys, entities) + if err != nil { + t.Fatalf("client.PutMulti: %v", err) + } + return keys, entities, now, parent, func() { + err := client.DeleteMulti(ctx, keys) + if err != nil { + t.Errorf("client.DeleteMulti: %v", err) + } + } +} + +type runWithOptionsTestcase struct { + desc string + wantKeys []*Key + wantExplainMetrics *ExplainMetrics + wantEntities []SQChild + opts []RunOption +} + +func getRunWithOptionsTestcases(ctx context.Context, t *testing.T, client *Client, partialNameKey string, count int) ([]runWithOptionsTestcase, int64, *Key, func()) { + keys, entities, now, parent, cleanup := createTestEntities(ctx, t, client, partialNameKey, count) + return []runWithOptionsTestcase{ + { + desc: "No ExplainOptions", + wantKeys: keys, + wantEntities: entities, + }, + { + desc: "ExplainOptions.Analyze is false", + opts: []RunOption{ExplainOptions{}}, + wantExplainMetrics: &ExplainMetrics{ + PlanSummary: &PlanSummary{ + IndexesUsed: []*map[string]interface{}{ + { + "properties": "(T ASC, I ASC, __name__ ASC)", + "query_scope": "Includes ancestors", + }, + }, + }, + }, + }, + { + desc: "ExplainOptions.Analyze is true", + opts: []RunOption{ExplainOptions{Analyze: true}}, + wantKeys: keys, + wantExplainMetrics: &ExplainMetrics{ + ExecutionStats: &ExecutionStats{ + ReadOperations: int64(count), + ResultsReturned: int64(count), + DebugStats: &map[string]interface{}{ + "documents_scanned": fmt.Sprint(count), + "index_entries_scanned": fmt.Sprint(count), + }, + }, + PlanSummary: &PlanSummary{ + IndexesUsed: []*map[string]interface{}{ + { + "properties": "(T ASC, I ASC, __name__ ASC)", + "query_scope": "Includes ancestors", + }, + }, + }, + }, + wantEntities: entities, + }, + }, now, parent, cleanup +} + +func TestIntegration_GetAllWithOptions(t *testing.T) { + ctx := context.Background() + client := newTestClient(ctx, t) + defer client.Close() + testcases, now, parent, cleanup := getRunWithOptionsTestcases(ctx, t, client, "GetAllWithOptions", 3) + defer cleanup() + query := NewQuery("SQChild").Ancestor(parent).Filter("T=", now).Order("I") + for _, testcase := range testcases { + var gotSQChildsFromGetAll []SQChild + gotRes, gotErr := client.GetAllWithOptions(ctx, query, &gotSQChildsFromGetAll, testcase.opts...) + if gotErr != nil { + t.Errorf("%v err: got: %+v, want: nil", testcase.desc, gotErr) + } + if !testutil.Equal(gotSQChildsFromGetAll, testcase.wantEntities) { + t.Errorf("%v entities: got: %+v, want: %+v", testcase.desc, gotSQChildsFromGetAll, testcase.wantEntities) + } + if !testutil.Equal(gotRes.Keys, testcase.wantKeys) { + t.Errorf("%v keys: got: %+v, want: %+v", testcase.desc, gotRes.Keys, testcase.wantKeys) + } + if err := cmpExplainMetrics(gotRes.ExplainMetrics, testcase.wantExplainMetrics); err != nil { + t.Errorf("%v %+v", testcase.desc, err) + } + } +} + +func TestIntegration_RunWithOptions(t *testing.T) { + ctx := context.Background() + client := newTestClient(ctx, t) + defer client.Close() + testcases, now, parent, cleanup := getRunWithOptionsTestcases(ctx, t, client, "RunWithOptions", 3) + defer cleanup() + query := NewQuery("SQChild").Ancestor(parent).Filter("T=", now).Order("I") + for _, testcase := range testcases { + var gotSQChildsFromRun []SQChild + iter := client.RunWithOptions(ctx, query, testcase.opts...) + for { + var gotSQChild SQChild + _, err := iter.Next(&gotSQChild) + if err == iterator.Done { + break + } + if err != nil { + t.Errorf("%v iter.Next: %v", testcase.desc, err) + } + gotSQChildsFromRun = append(gotSQChildsFromRun, gotSQChild) + } + if !testutil.Equal(gotSQChildsFromRun, testcase.wantEntities) { + t.Errorf("%v entities: got: %+v, want: %+v", testcase.desc, gotSQChildsFromRun, testcase.wantEntities) + } + + if err := cmpExplainMetrics(iter.ExplainMetrics, testcase.wantExplainMetrics); err != nil { + t.Errorf("%v %+v", testcase.desc, err) + } + } +} + +func cmpExplainMetrics(got *ExplainMetrics, want *ExplainMetrics) error { + if (got != nil && want == nil) || (got == nil && want != nil) { + return fmt.Errorf("ExplainMetrics: got: %+v, want: %+v", got, want) + } + if got == nil { + return nil + } + if !testutil.Equal(got.PlanSummary, want.PlanSummary) { + return fmt.Errorf("Plan: got: %+v, want: %+v", got.PlanSummary, want.PlanSummary) + } + if err := cmpExecutionStats(got.ExecutionStats, want.ExecutionStats); err != nil { + return err + } + return nil +} + +func cmpExecutionStats(got *ExecutionStats, want *ExecutionStats) error { + if (got != nil && want == nil) || (got == nil && want != nil) { + return fmt.Errorf("ExecutionStats: got: %+v, want: %+v", got, want) + } + if got == nil { + return nil + } + + // Compare all fields except DebugStats + if !testutil.Equal(want, got, cmpopts.IgnoreFields(ExecutionStats{}, "DebugStats", "ExecutionDuration")) { + return fmt.Errorf("ExecutionStats: mismatch (-want +got):\n%s", testutil.Diff(want, got, cmpopts.IgnoreFields(ExecutionStats{}, "DebugStats"))) + } + + // Compare DebugStats + gotDebugStats := *got.DebugStats + for wantK, wantV := range *want.DebugStats { + // ExecutionStats.Debugstats has some keys whose values cannot be predicted. So, those values have not been included in want + // Here, compare only those values included in want + gotV, ok := gotDebugStats[wantK] + if !ok || !testutil.Equal(gotV, wantV) { + return fmt.Errorf("ExecutionStats.DebugStats: wantKey: %v gotValue: %+v, wantValue: %+v", wantK, gotV, wantV) + } + } + + return nil +} + func TestIntegration_KindlessQueries(t *testing.T) { ctx := context.Background() client := newTestClient(ctx, t) diff --git a/datastore/query.go b/datastore/query.go index 929c133e6189..d4dec92a8159 100644 --- a/datastore/query.go +++ b/datastore/query.go @@ -23,7 +23,9 @@ import ( "reflect" "strconv" "strings" + "time" + "cloud.google.com/go/internal/protostruct" "cloud.google.com/go/internal/trace" "google.golang.org/api/iterator" pb "google.golang.org/genproto/googleapis/datastore/v1" @@ -627,6 +629,107 @@ func (c *Client) Count(ctx context.Context, q *Query) (n int, err error) { } } +// RunOption lets the user provide options while running a query +type RunOption interface { + apply(*runQuerySettings) error +} + +// ExplainOptions is explain options for the query. +// +// Query Explain feature is still in preview and not yet publicly available. +// Pre-GA features might have limited support and can change at any time. +type ExplainOptions struct { + // When false (the default), the query will be planned, returning only + // metrics from the planning stages. + // When true, the query will be planned and executed, returning the full + // query results along with both planning and execution stage metrics. + Analyze bool +} + +func (e ExplainOptions) apply(s *runQuerySettings) error { + if s.explainOptions != nil { + return errors.New("datastore: ExplainOptions can be specified only once") + } + pbExplainOptions := pb.ExplainOptions{ + Analyze: e.Analyze, + } + s.explainOptions = &pbExplainOptions + return nil +} + +type runQuerySettings struct { + explainOptions *pb.ExplainOptions +} + +// newRunQuerySettings creates a runQuerySettings with a given RunOption slice. +func newRunQuerySettings(opts []RunOption) (*runQuerySettings, error) { + s := &runQuerySettings{} + for _, o := range opts { + if o == nil { + return nil, errors.New("datastore: RunOption cannot be nil") + } + err := o.apply(s) + if err != nil { + return nil, err + } + } + return s, nil +} + +// ExplainMetrics for the query. +type ExplainMetrics struct { + // Planning phase information for the query. + PlanSummary *PlanSummary + + // Aggregated stats from the execution of the query. Only present when + // ExplainOptions.Analyze is set to true. + ExecutionStats *ExecutionStats +} + +// PlanSummary represents planning phase information for the query. +type PlanSummary struct { + // The indexes selected for the query. For example: + // + // [ + // {"query_scope": "Collection", "properties": "(foo ASC, __name__ ASC)"}, + // {"query_scope": "Collection", "properties": "(bar ASC, __name__ ASC)"} + // ] + IndexesUsed []*map[string]interface{} +} + +// ExecutionStats represents execution statistics for the query. +type ExecutionStats struct { + // Total number of results returned, including documents, projections, + // aggregation results, keys. + ResultsReturned int64 + // Total time to execute the query in the backend. + ExecutionDuration *time.Duration + // Total billable read operations. + ReadOperations int64 + // Debugging statistics from the execution of the query. Note that the + // debugging stats are subject to change as Firestore evolves. It could + // include: + // + // { + // "indexes_entries_scanned": "1000", + // "documents_scanned": "20", + // "billing_details" : { + // "documents_billable": "20", + // "index_entries_billable": "1000", + // "min_query_cost": "0" + // } + // } + DebugStats *map[string]interface{} +} + +// GetAllWithOptionsResult is the result of call to GetAllWithOptions method +type GetAllWithOptionsResult struct { + Keys []*Key + + // Query explain metrics. This is only present when ExplainOptions is provided. + ExplainMetrics *ExplainMetrics +} + // GetAll runs the provided query in the given context and returns all keys // that match that query, as well as appending the values to dst. // @@ -651,6 +754,15 @@ func (c *Client) GetAll(ctx context.Context, q *Query, dst interface{}) (keys [] ctx = trace.StartSpan(ctx, "cloud.google.com/go/datastore.Query.GetAll") defer func() { trace.EndSpan(ctx, err) }() + res, err := c.GetAllWithOptions(ctx, q, dst) + return res.Keys, err +} + +// GetAllWithOptions is similar to GetAll but runs the query with provided options +func (c *Client) GetAllWithOptions(ctx context.Context, q *Query, dst interface{}, opts ...RunOption) (res GetAllWithOptionsResult, err error) { + ctx = trace.StartSpan(ctx, "cloud.google.com/go/datastore.Query.GetAllWithOptions") + defer func() { trace.EndSpan(ctx, err) }() + var ( dv reflect.Value mat multiArgType @@ -660,22 +772,23 @@ func (c *Client) GetAll(ctx context.Context, q *Query, dst interface{}) (keys [] if !q.keysOnly { dv = reflect.ValueOf(dst) if dv.Kind() != reflect.Ptr || dv.IsNil() { - return nil, ErrInvalidEntityType + return res, ErrInvalidEntityType } dv = dv.Elem() mat, elemType = checkMultiArg(dv) if mat == multiArgTypeInvalid || mat == multiArgTypeInterface { - return nil, ErrInvalidEntityType + return res, ErrInvalidEntityType } } - for t := c.Run(ctx, q); ; { + for t := c.RunWithOptions(ctx, q, opts...); ; { k, e, err := t.next() + res.ExplainMetrics = t.ExplainMetrics if err == iterator.Done { break } if err != nil { - return keys, err + return res, err } if !q.keysOnly { ev := reflect.New(elemType) @@ -702,7 +815,7 @@ func (c *Client) GetAll(ctx context.Context, q *Query, dst interface{}) (keys [] // an ErrFieldMismatch is returned. errFieldMismatch = err } else { - return keys, err + return res, err } } if mat != multiArgTypeStructPtr { @@ -710,21 +823,27 @@ func (c *Client) GetAll(ctx context.Context, q *Query, dst interface{}) (keys [] } dv.Set(reflect.Append(dv, ev)) } - keys = append(keys, k) + res.Keys = append(res.Keys, k) } - return keys, errFieldMismatch + return res, errFieldMismatch } -// Run runs the given query in the given context. +// Run runs the given query in the given context func (c *Client) Run(ctx context.Context, q *Query) (it *Iterator) { ctx = trace.StartSpan(ctx, "cloud.google.com/go/datastore.Query.Run") defer func() { trace.EndSpan(ctx, it.err) }() - it = c.run(ctx, q) - return it + return c.run(ctx, q) } -// run runs the given query in the given context. -func (c *Client) run(ctx context.Context, q *Query) *Iterator { +// RunWithOptions runs the given query in the given context with the provided options +func (c *Client) RunWithOptions(ctx context.Context, q *Query, opts ...RunOption) (it *Iterator) { + ctx = trace.StartSpan(ctx, "cloud.google.com/go/datastore.Query.RunWithOptions") + defer func() { trace.EndSpan(ctx, it.err) }() + return c.run(ctx, q, opts...) +} + +// run runs the given query in the given context with the provided options +func (c *Client) run(ctx context.Context, q *Query, opts ...RunOption) *Iterator { if q.err != nil { return &Iterator{ctx: ctx, err: q.err} } @@ -750,6 +869,16 @@ func (c *Client) run(ctx context.Context, q *Query) *Iterator { } } + runSettings, err := newRunQuerySettings(opts) + if err != nil { + t.err = err + return t + } + + if runSettings.explainOptions != nil { + t.req.ExplainOptions = runSettings.explainOptions + } + if err := q.toRunQueryRequest(t.req); err != nil { t.err = err } @@ -760,22 +889,30 @@ func (c *Client) run(ctx context.Context, q *Query) *Iterator { func (c *Client) RunAggregationQuery(ctx context.Context, aq *AggregationQuery) (ar AggregationResult, err error) { ctx = trace.StartSpan(ctx, "cloud.google.com/go/datastore.Query.RunAggregationQuery") defer func() { trace.EndSpan(ctx, err) }() + aro, err := c.RunAggregationQueryWithOptions(ctx, aq) + return aro.Result, err +} + +// RunAggregationQueryWithOptions runs aggregation query (e.g. COUNT) with provided options and returns results from the service. +func (c *Client) RunAggregationQueryWithOptions(ctx context.Context, aq *AggregationQuery, opts ...RunOption) (ar AggregationWithOptionsResult, err error) { + ctx = trace.StartSpan(ctx, "cloud.google.com/go/datastore.Query.RunAggregationQueryWithOptions") + defer func() { trace.EndSpan(ctx, err) }() if aq == nil { - return nil, errors.New("datastore: aggregation query cannot be nil") + return ar, errors.New("datastore: aggregation query cannot be nil") } if aq.query == nil { - return nil, errors.New("datastore: aggregation query must include nested query") + return ar, errors.New("datastore: aggregation query must include nested query") } if len(aq.aggregationQueries) == 0 { - return nil, errors.New("datastore: aggregation query must contain one or more operators (e.g. count)") + return ar, errors.New("datastore: aggregation query must contain one or more operators (e.g. count)") } q, err := aq.query.toProto() if err != nil { - return nil, err + return ar, err } req := &pb.RunAggregationQueryRequest{ @@ -797,6 +934,14 @@ func (c *Client) RunAggregationQuery(ctx context.Context, aq *AggregationQuery) } } + runSettings, err := newRunQuerySettings(opts) + if err != nil { + return ar, err + } + if runSettings.explainOptions != nil { + req.ExplainOptions = runSettings.explainOptions + } + // Parse the read options. txn := aq.query.trans if txn != nil { @@ -805,27 +950,29 @@ func (c *Client) RunAggregationQuery(ctx context.Context, aq *AggregationQuery) req.ReadOptions, err = parseQueryReadOptions(aq.query.eventual, txn) if err != nil { - return nil, err + return ar, err } - res, err := c.client.RunAggregationQuery(ctx, req) + resp, err := c.client.RunAggregationQuery(ctx, req) if err != nil { - return nil, err + return ar, err } if txn != nil && txn.state == transactionStateNotStarted { - txn.setToInProgress(res.Transaction) + txn.setToInProgress(resp.Transaction) } - ar = make(AggregationResult) - - // TODO(developer): change batch parsing logic if other aggregations are supported. - for _, a := range res.Batch.AggregationResults { - for k, v := range a.AggregateProperties { - ar[k] = v + if req.ExplainOptions == nil || req.ExplainOptions.Analyze { + ar.Result = make(AggregationResult) + // TODO(developer): change batch parsing logic if other aggregations are supported. + for _, a := range resp.Batch.AggregationResults { + for k, v := range a.AggregateProperties { + ar.Result[k] = v + } } } + ar.ExplainMetrics = fromPbExplainMetrics(resp.GetExplainMetrics()) return ar, nil } @@ -890,6 +1037,9 @@ type Iterator struct { // entityCursor is the compiled cursor of the next result. entityCursor []byte + // Query explain metrics. This is only present when ExplainOptions is used. + ExplainMetrics *ExplainMetrics + // trans records the transaction in which the query was run trans *Transaction @@ -982,6 +1132,13 @@ func (t *Iterator) nextBatch() error { txn.setToInProgress(resp.Transaction) } + if t.req.ExplainOptions != nil && !t.req.ExplainOptions.Analyze { + // No results to process + t.limit = 0 + t.ExplainMetrics = fromPbExplainMetrics(resp.GetExplainMetrics()) + return nil + } + // Adjust any offset from skipped results. skip := resp.Batch.SkippedResults if skip < 0 { @@ -1021,9 +1178,56 @@ func (t *Iterator) nextBatch() error { t.pageCursor = resp.Batch.EndCursor t.results = resp.Batch.EntityResults + t.ExplainMetrics = fromPbExplainMetrics(resp.GetExplainMetrics()) return nil } +func fromPbExplainMetrics(pbExplainMetrics *pb.ExplainMetrics) *ExplainMetrics { + if pbExplainMetrics == nil { + return nil + } + explainMetrics := &ExplainMetrics{ + PlanSummary: fromPbPlanSummary(pbExplainMetrics.PlanSummary), + ExecutionStats: fromPbExecutionStats(pbExplainMetrics.ExecutionStats), + } + return explainMetrics +} + +func fromPbPlanSummary(pbPlanSummary *pb.PlanSummary) *PlanSummary { + if pbPlanSummary == nil { + return nil + } + + planSummary := &PlanSummary{} + indexesUsed := []*map[string]interface{}{} + for _, pbIndexUsed := range pbPlanSummary.GetIndexesUsed() { + indexUsed := protostruct.DecodeToMap(pbIndexUsed) + indexesUsed = append(indexesUsed, &indexUsed) + } + + planSummary.IndexesUsed = indexesUsed + return planSummary +} + +func fromPbExecutionStats(pbstats *pb.ExecutionStats) *ExecutionStats { + if pbstats == nil { + return nil + } + + executionStats := &ExecutionStats{ + ResultsReturned: pbstats.GetResultsReturned(), + ReadOperations: pbstats.GetReadOperations(), + } + + executionDuration := pbstats.GetExecutionDuration().AsDuration() + executionStats.ExecutionDuration = &executionDuration + + debugStats := protostruct.DecodeToMap(pbstats.GetDebugStats()) + executionStats.DebugStats = &debugStats + + return executionStats +} + // Cursor returns a cursor for the iterator's current location. func (t *Iterator) Cursor() (c Cursor, err error) { t.ctx = trace.StartSpan(t.ctx, "cloud.google.com/go/datastore.Query.Cursor") @@ -1154,3 +1358,11 @@ func (aq *AggregationQuery) WithAvg(fieldName string, alias string) *Aggregation // AggregationResult contains the results of an aggregation query. type AggregationResult map[string]interface{} + +// AggregationWithOptionsResult contains the results of an aggregation query run with options. +type AggregationWithOptionsResult struct { + Result AggregationResult + + // Query explain metrics. This is only present when ExplainOptions is provided. + ExplainMetrics *ExplainMetrics +} diff --git a/datastore/query_test.go b/datastore/query_test.go index 4883e847acc4..9c6d60ab1162 100644 --- a/datastore/query_test.go +++ b/datastore/query_test.go @@ -20,6 +20,7 @@ import ( "fmt" "reflect" "sort" + "strings" "testing" "cloud.google.com/go/internal/testutil" @@ -885,6 +886,69 @@ func TestAggregationQueryIsNil(t *testing.T) { } } +func TestExplainOptionsApply(t *testing.T) { + pbExplainOptions := pb.ExplainOptions{ + Analyze: true, + } + for _, testcase := range []struct { + desc string + existingOptions *pb.ExplainOptions + newOptions ExplainOptions + wantErrMsg string + }{ + { + desc: "ExplainOptions specified multiple times", + existingOptions: &pbExplainOptions, + newOptions: ExplainOptions{ + Analyze: true, + }, + wantErrMsg: "ExplainOptions can be specified only once", + }, + { + desc: "ExplainOptions specified once", + existingOptions: nil, + newOptions: ExplainOptions{ + Analyze: true, + }, + }, + } { + gotErr := testcase.newOptions.apply(&runQuerySettings{explainOptions: testcase.existingOptions}) + if (gotErr == nil && testcase.wantErrMsg != "") || + (gotErr != nil && !strings.Contains(gotErr.Error(), testcase.wantErrMsg)) { + t.Errorf("%v: apply got: %v want: %v", testcase.desc, gotErr.Error(), testcase.wantErrMsg) + } + } +} + +func TestNewRunQuerySettings(t *testing.T) { + for _, testcase := range []struct { + desc string + opts []RunOption + wantErrMsg string + }{ + { + desc: "nil RunOption", + opts: []RunOption{ExplainOptions{Analyze: true}, nil}, + wantErrMsg: "cannot be nil", + }, + { + desc: "success RunOption", + opts: []RunOption{ExplainOptions{Analyze: true}}, + }, + { + desc: "ExplainOptions specified multiple times", + opts: []RunOption{ExplainOptions{Analyze: true}, ExplainOptions{Analyze: false}, ExplainOptions{Analyze: true}}, + wantErrMsg: "ExplainOptions can be specified only once", + }, + } { + _, gotErr := newRunQuerySettings(testcase.opts) + if (gotErr == nil && testcase.wantErrMsg != "") || + (gotErr != nil && !strings.Contains(gotErr.Error(), testcase.wantErrMsg)) { + t.Errorf("%v: newRunQuerySettings got: %v want: %v", testcase.desc, gotErr, testcase.wantErrMsg) + } + } +} + func TestValidateReadOptions(t *testing.T) { eventualInTxnErr := errEventualConsistencyTransaction