diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/Statement.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/Statement.java
index ea6cdf3f65c..a89c7c048fc 100644
--- a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/Statement.java
+++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/Statement.java
@@ -178,6 +178,11 @@ public String getSql() {
return sql;
}
+ /** Returns a copy of this statement with the SQL string replaced by the given SQL string. */
+ public Statement withReplacedSql(String sql) {
+ return new Statement(sql, this.parameters, this.queryOptions);
+ }
+
/** Returns the {@link QueryOptions} that will be used with this {@link Statement}. */
public QueryOptions getQueryOptions() {
return queryOptions;
diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/connection/AbstractStatementParser.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/connection/AbstractStatementParser.java
index d0c06fa1d9d..b45d444b744 100644
--- a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/connection/AbstractStatementParser.java
+++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/connection/AbstractStatementParser.java
@@ -16,9 +16,13 @@
package com.google.cloud.spanner.connection;
+import static com.google.cloud.spanner.connection.SimpleParser.isValidIdentifierChar;
+import static com.google.cloud.spanner.connection.StatementHintParser.convertHintsToOptions;
+
import com.google.api.core.InternalApi;
import com.google.cloud.spanner.Dialect;
import com.google.cloud.spanner.ErrorCode;
+import com.google.cloud.spanner.Options.ReadQueryUpdateTransactionOption;
import com.google.cloud.spanner.SpannerException;
import com.google.cloud.spanner.SpannerExceptionFactory;
import com.google.cloud.spanner.Statement;
@@ -169,6 +173,7 @@ public static class ParsedStatement {
private final Statement statement;
private final String sqlWithoutComments;
private final boolean returningClause;
+ private final ReadQueryUpdateTransactionOption[] optionsFromHints;
private static ParsedStatement clientSideStatement(
ClientSideStatementImpl clientSideStatement,
@@ -182,15 +187,27 @@ private static ParsedStatement ddl(Statement statement, String sqlWithoutComment
}
private static ParsedStatement query(
- Statement statement, String sqlWithoutComments, QueryOptions defaultQueryOptions) {
+ Statement statement,
+ String sqlWithoutComments,
+ QueryOptions defaultQueryOptions,
+ ReadQueryUpdateTransactionOption[] optionsFromHints) {
return new ParsedStatement(
- StatementType.QUERY, null, statement, sqlWithoutComments, defaultQueryOptions, false);
+ StatementType.QUERY,
+ null,
+ statement,
+ sqlWithoutComments,
+ defaultQueryOptions,
+ false,
+ optionsFromHints);
}
private static ParsedStatement update(
- Statement statement, String sqlWithoutComments, boolean returningClause) {
+ Statement statement,
+ String sqlWithoutComments,
+ boolean returningClause,
+ ReadQueryUpdateTransactionOption[] optionsFromHints) {
return new ParsedStatement(
- StatementType.UPDATE, statement, sqlWithoutComments, returningClause);
+ StatementType.UPDATE, statement, sqlWithoutComments, returningClause, optionsFromHints);
}
private static ParsedStatement unknown(Statement statement, String sqlWithoutComments) {
@@ -208,18 +225,20 @@ private ParsedStatement(
this.statement = statement;
this.sqlWithoutComments = Preconditions.checkNotNull(sqlWithoutComments);
this.returningClause = false;
+ this.optionsFromHints = EMPTY_OPTIONS;
}
private ParsedStatement(
StatementType type,
Statement statement,
String sqlWithoutComments,
- boolean returningClause) {
- this(type, null, statement, sqlWithoutComments, null, returningClause);
+ boolean returningClause,
+ ReadQueryUpdateTransactionOption[] optionsFromHints) {
+ this(type, null, statement, sqlWithoutComments, null, returningClause, optionsFromHints);
}
private ParsedStatement(StatementType type, Statement statement, String sqlWithoutComments) {
- this(type, null, statement, sqlWithoutComments, null, false);
+ this(type, null, statement, sqlWithoutComments, null, false, EMPTY_OPTIONS);
}
private ParsedStatement(
@@ -228,33 +247,37 @@ private ParsedStatement(
Statement statement,
String sqlWithoutComments,
QueryOptions defaultQueryOptions,
- boolean returningClause) {
+ boolean returningClause,
+ ReadQueryUpdateTransactionOption[] optionsFromHints) {
Preconditions.checkNotNull(type);
this.type = type;
this.clientSideStatement = clientSideStatement;
this.statement = statement == null ? null : mergeQueryOptions(statement, defaultQueryOptions);
this.sqlWithoutComments = Preconditions.checkNotNull(sqlWithoutComments);
this.returningClause = returningClause;
+ this.optionsFromHints = optionsFromHints;
}
private ParsedStatement copy(Statement statement, QueryOptions defaultQueryOptions) {
return new ParsedStatement(
this.type,
this.clientSideStatement,
- statement,
+ statement.withReplacedSql(this.statement.getSql()),
this.sqlWithoutComments,
defaultQueryOptions,
- this.returningClause);
+ this.returningClause,
+ this.optionsFromHints);
}
private ParsedStatement forCache() {
return new ParsedStatement(
this.type,
this.clientSideStatement,
- null,
+ Statement.of(this.statement.getSql()),
this.sqlWithoutComments,
null,
- this.returningClause);
+ this.returningClause,
+ this.optionsFromHints);
}
@Override
@@ -287,6 +310,11 @@ public boolean hasReturningClause() {
return this.returningClause;
}
+ @InternalApi
+ public ReadQueryUpdateTransactionOption[] getOptionsFromHints() {
+ return this.optionsFromHints;
+ }
+
/**
* @return true if the statement is a query that will return a {@link
* com.google.cloud.spanner.ResultSet}.
@@ -480,14 +508,23 @@ ParsedStatement parse(Statement statement, QueryOptions defaultQueryOptions) {
}
private ParsedStatement internalParse(Statement statement, QueryOptions defaultQueryOptions) {
+ StatementHintParser statementHintParser =
+ new StatementHintParser(getDialect(), statement.getSql());
+ ReadQueryUpdateTransactionOption[] optionsFromHints = EMPTY_OPTIONS;
+ if (statementHintParser.hasStatementHints()
+ && !statementHintParser.getClientSideStatementHints().isEmpty()) {
+ statement =
+ statement.toBuilder().replace(statementHintParser.getSqlWithoutClientSideHints()).build();
+ optionsFromHints = convertHintsToOptions(statementHintParser.getClientSideStatementHints());
+ }
String sql = removeCommentsAndTrim(statement.getSql());
ClientSideStatementImpl client = parseClientSideStatement(sql);
if (client != null) {
return ParsedStatement.clientSideStatement(client, statement, sql);
} else if (isQuery(sql)) {
- return ParsedStatement.query(statement, sql, defaultQueryOptions);
+ return ParsedStatement.query(statement, sql, defaultQueryOptions, optionsFromHints);
} else if (isUpdateStatement(sql)) {
- return ParsedStatement.update(statement, sql, checkReturningClause(sql));
+ return ParsedStatement.update(statement, sql, checkReturningClause(sql), optionsFromHints);
} else if (isDdlStatement(sql)) {
return ParsedStatement.ddl(statement, sql);
}
@@ -621,6 +658,10 @@ public String removeCommentsAndTrim(String sql) {
/** Removes any statement hints at the beginning of the statement. */
abstract String removeStatementHint(String sql);
+ @VisibleForTesting
+ static final ReadQueryUpdateTransactionOption[] EMPTY_OPTIONS =
+ new ReadQueryUpdateTransactionOption[0];
+
/** Parameter information with positional parameters translated to named parameters. */
@InternalApi
public static class ParametersInfo {
@@ -697,9 +738,10 @@ public boolean checkReturningClause(String sql) {
return checkReturningClauseInternal(sql);
}
+ abstract Dialect getDialect();
+
/**
- * <<<<<<< HEAD Returns true if this dialect supports nested comments. ======= <<<<<<< HEAD
- * Returns true if this dialect supports nested comments. >>>>>>> main
+ * Returns true if this dialect supports nested comments.
*
*
* - This method should return false for dialects that consider this to be a valid comment:
@@ -757,18 +799,6 @@ public boolean checkReturningClause(String sql) {
/** Returns the query parameter prefix that should be used for this dialect. */
abstract String getQueryParameterPrefix();
- /**
- * Returns true for characters that can be used as the first character in unquoted identifiers.
- */
- boolean isValidIdentifierFirstChar(char c) {
- return Character.isLetter(c) || c == UNDERSCORE;
- }
-
- /** Returns true for characters that can be used in unquoted identifiers. */
- boolean isValidIdentifierChar(char c) {
- return isValidIdentifierFirstChar(c) || Character.isDigit(c) || c == DOLLAR;
- }
-
/** Reads a dollar-quoted string literal from position index in the given sql string. */
String parseDollarQuotedString(String sql, int index) {
// Look ahead to the next dollar sign (if any). Everything in between is the quote tag.
@@ -812,9 +842,9 @@ int skip(String sql, int currentIndex, @Nullable StringBuilder result) {
} else if (currentChar == HYPHEN
&& sql.length() > (currentIndex + 1)
&& sql.charAt(currentIndex + 1) == HYPHEN) {
- return skipSingleLineComment(sql, currentIndex, result);
+ return skipSingleLineComment(sql, /* prefixLength = */ 2, currentIndex, result);
} else if (currentChar == DASH && supportsHashSingleLineComments()) {
- return skipSingleLineComment(sql, currentIndex, result);
+ return skipSingleLineComment(sql, /* prefixLength = */ 1, currentIndex, result);
} else if (currentChar == SLASH
&& sql.length() > (currentIndex + 1)
&& sql.charAt(currentIndex + 1) == ASTERISK) {
@@ -826,44 +856,31 @@ int skip(String sql, int currentIndex, @Nullable StringBuilder result) {
}
/** Skips a single-line comment from startIndex and adds it to result if result is not null. */
- static int skipSingleLineComment(String sql, int startIndex, @Nullable StringBuilder result) {
- int endIndex = sql.indexOf('\n', startIndex + 2);
- if (endIndex == -1) {
- endIndex = sql.length();
- } else {
- // Include the newline character.
- endIndex++;
+ int skipSingleLineComment(
+ String sql, int prefixLength, int startIndex, @Nullable StringBuilder result) {
+ return skipSingleLineComment(getDialect(), sql, prefixLength, startIndex, result);
+ }
+
+ static int skipSingleLineComment(
+ Dialect dialect,
+ String sql,
+ int prefixLength,
+ int startIndex,
+ @Nullable StringBuilder result) {
+ SimpleParser simpleParser = new SimpleParser(dialect, sql, startIndex, false);
+ if (simpleParser.skipSingleLineComment(prefixLength)) {
+ appendIfNotNull(result, sql.substring(startIndex, simpleParser.getPos()));
}
- appendIfNotNull(result, sql.substring(startIndex, endIndex));
- return endIndex;
+ return simpleParser.getPos();
}
/** Skips a multi-line comment from startIndex and adds it to result if result is not null. */
int skipMultiLineComment(String sql, int startIndex, @Nullable StringBuilder result) {
- // Current position is start + '/*'.length().
- int pos = startIndex + 2;
- // PostgreSQL allows comments to be nested. That is, the following is allowed:
- // '/* test /* inner comment */ still a comment */'
- int level = 1;
- while (pos < sql.length()) {
- if (supportsNestedComments()
- && sql.charAt(pos) == SLASH
- && sql.length() > (pos + 1)
- && sql.charAt(pos + 1) == ASTERISK) {
- level++;
- }
- if (sql.charAt(pos) == ASTERISK && sql.length() > (pos + 1) && sql.charAt(pos + 1) == SLASH) {
- level--;
- if (level == 0) {
- pos += 2;
- appendIfNotNull(result, sql.substring(startIndex, pos));
- return pos;
- }
- }
- pos++;
+ SimpleParser simpleParser = new SimpleParser(getDialect(), sql, startIndex, false);
+ if (simpleParser.skipMultiLineComment()) {
+ appendIfNotNull(result, sql.substring(startIndex, simpleParser.getPos()));
}
- appendIfNotNull(result, sql.substring(startIndex));
- return sql.length();
+ return simpleParser.getPos();
}
/** Skips a quoted string from startIndex. */
diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/connection/ConnectionImpl.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/connection/ConnectionImpl.java
index 1586c2dcb9e..3f33afcd873 100644
--- a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/connection/ConnectionImpl.java
+++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/connection/ConnectionImpl.java
@@ -32,6 +32,7 @@
import com.google.cloud.spanner.Mutation;
import com.google.cloud.spanner.Options;
import com.google.cloud.spanner.Options.QueryOption;
+import com.google.cloud.spanner.Options.ReadQueryUpdateTransactionOption;
import com.google.cloud.spanner.Options.RpcPriority;
import com.google.cloud.spanner.Options.UpdateOption;
import com.google.cloud.spanner.PartitionOptions;
@@ -1154,6 +1155,7 @@ public ResultSet partitionQuery(
"Only queries can be partitioned. Invalid statement: " + query.getSql());
}
+ QueryOption[] combinedOptions = concat(parsedStatement.getOptionsFromHints(), options);
UnitOfWork transaction = getCurrentUnitOfWorkOrStartNewUnitOfWork();
return get(
transaction.partitionQueryAsync(
@@ -1161,7 +1163,8 @@ public ResultSet partitionQuery(
parsedStatement,
getEffectivePartitionOptions(partitionOptions),
mergeDataBoost(
- mergeQueryRequestOptions(parsedStatement, mergeQueryStatementTag(options)))));
+ mergeQueryRequestOptions(
+ parsedStatement, mergeQueryStatementTag(combinedOptions)))));
}
private PartitionOptions getEffectivePartitionOptions(
@@ -1455,6 +1458,34 @@ private List parseUpdateStatements(Iterable updates)
return parsedStatements;
}
+ private UpdateOption[] concat(
+ ReadQueryUpdateTransactionOption[] statementOptions, UpdateOption[] argumentOptions) {
+ if (statementOptions == null || statementOptions.length == 0) {
+ return argumentOptions;
+ }
+ if (argumentOptions == null || argumentOptions.length == 0) {
+ return statementOptions;
+ }
+ UpdateOption[] result =
+ Arrays.copyOf(statementOptions, statementOptions.length + argumentOptions.length);
+ System.arraycopy(argumentOptions, 0, result, statementOptions.length, argumentOptions.length);
+ return result;
+ }
+
+ private QueryOption[] concat(
+ ReadQueryUpdateTransactionOption[] statementOptions, QueryOption[] argumentOptions) {
+ if (statementOptions == null || statementOptions.length == 0) {
+ return argumentOptions;
+ }
+ if (argumentOptions == null || argumentOptions.length == 0) {
+ return statementOptions;
+ }
+ QueryOption[] result =
+ Arrays.copyOf(statementOptions, statementOptions.length + argumentOptions.length);
+ System.arraycopy(argumentOptions, 0, result, statementOptions.length, argumentOptions.length);
+ return result;
+ }
+
private QueryOption[] mergeDataBoost(QueryOption... options) {
if (this.dataBoostEnabled) {
options = appendQueryOption(options, Options.dataBoostEnabled(true));
@@ -1531,19 +1562,20 @@ private ResultSet internalExecuteQuery(
&& (analyzeMode != AnalyzeMode.NONE || statement.hasReturningClause())),
"Statement must either be a query or a DML mode with analyzeMode!=NONE or returning clause");
boolean isInternalMetadataQuery = isInternalMetadataQuery(options);
+ QueryOption[] combinedOptions = concat(statement.getOptionsFromHints(), options);
UnitOfWork transaction = getCurrentUnitOfWorkOrStartNewUnitOfWork(isInternalMetadataQuery);
if (autoPartitionMode
&& statement.getType() == StatementType.QUERY
&& !isInternalMetadataQuery) {
return runPartitionedQuery(
- statement.getStatement(), PartitionOptions.getDefaultInstance(), options);
+ statement.getStatement(), PartitionOptions.getDefaultInstance(), combinedOptions);
}
return get(
transaction.executeQueryAsync(
callType,
statement,
analyzeMode,
- mergeQueryRequestOptions(statement, mergeQueryStatementTag(options))));
+ mergeQueryRequestOptions(statement, mergeQueryStatementTag(combinedOptions))));
}
private AsyncResultSet internalExecuteQueryAsync(
@@ -1558,25 +1590,27 @@ private AsyncResultSet internalExecuteQueryAsync(
ConnectionPreconditions.checkState(
!(autoPartitionMode && statement.getType() == StatementType.QUERY),
"Partitioned queries cannot be executed asynchronously");
- UnitOfWork transaction =
- getCurrentUnitOfWorkOrStartNewUnitOfWork(isInternalMetadataQuery(options));
+ boolean isInternalMetadataQuery = isInternalMetadataQuery(options);
+ QueryOption[] combinedOptions = concat(statement.getOptionsFromHints(), options);
+ UnitOfWork transaction = getCurrentUnitOfWorkOrStartNewUnitOfWork(isInternalMetadataQuery);
return ResultSets.toAsyncResultSet(
transaction.executeQueryAsync(
callType,
statement,
analyzeMode,
- mergeQueryRequestOptions(statement, mergeQueryStatementTag(options))),
+ mergeQueryRequestOptions(statement, mergeQueryStatementTag(combinedOptions))),
spanner.getAsyncExecutorProvider(),
- options);
+ combinedOptions);
}
private ApiFuture internalExecuteUpdateAsync(
final CallType callType, final ParsedStatement update, UpdateOption... options) {
Preconditions.checkArgument(
update.getType() == StatementType.UPDATE, "Statement must be an update");
+ UpdateOption[] combinedOptions = concat(update.getOptionsFromHints(), options);
UnitOfWork transaction = getCurrentUnitOfWorkOrStartNewUnitOfWork();
return transaction.executeUpdateAsync(
- callType, update, mergeUpdateRequestOptions(mergeUpdateStatementTag(options)));
+ callType, update, mergeUpdateRequestOptions(mergeUpdateStatementTag(combinedOptions)));
}
private ApiFuture internalAnalyzeUpdateAsync(
@@ -1586,16 +1620,22 @@ private ApiFuture internalAnalyzeUpdateAsync(
UpdateOption... options) {
Preconditions.checkArgument(
update.getType() == StatementType.UPDATE, "Statement must be an update");
+ UpdateOption[] combinedOptions = concat(update.getOptionsFromHints(), options);
UnitOfWork transaction = getCurrentUnitOfWorkOrStartNewUnitOfWork();
return transaction.analyzeUpdateAsync(
- callType, update, analyzeMode, mergeUpdateRequestOptions(mergeUpdateStatementTag(options)));
+ callType,
+ update,
+ analyzeMode,
+ mergeUpdateRequestOptions(mergeUpdateStatementTag(combinedOptions)));
}
private ApiFuture internalExecuteBatchUpdateAsync(
CallType callType, List updates, UpdateOption... options) {
+ UpdateOption[] combinedOptions =
+ updates.isEmpty() ? options : concat(updates.get(0).getOptionsFromHints(), options);
UnitOfWork transaction = getCurrentUnitOfWorkOrStartNewUnitOfWork();
return transaction.executeBatchUpdateAsync(
- callType, updates, mergeUpdateRequestOptions(mergeUpdateStatementTag(options)));
+ callType, updates, mergeUpdateRequestOptions(mergeUpdateStatementTag(combinedOptions)));
}
private UnitOfWork getCurrentUnitOfWorkOrStartNewUnitOfWork() {
diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/connection/PostgreSQLStatementParser.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/connection/PostgreSQLStatementParser.java
index be4aa9d7f46..4f39c549de9 100644
--- a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/connection/PostgreSQLStatementParser.java
+++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/connection/PostgreSQLStatementParser.java
@@ -16,6 +16,8 @@
package com.google.cloud.spanner.connection;
+import static com.google.cloud.spanner.connection.SimpleParser.isValidIdentifierFirstChar;
+
import com.google.api.core.InternalApi;
import com.google.cloud.spanner.Dialect;
import com.google.cloud.spanner.ErrorCode;
@@ -39,6 +41,11 @@ public class PostgreSQLStatementParser extends AbstractStatementParser {
ClientSideStatements.getInstance(Dialect.POSTGRESQL).getCompiledStatements()));
}
+ @Override
+ Dialect getDialect() {
+ return Dialect.POSTGRESQL;
+ }
+
/**
* Indicates whether the parser supports the {@code EXPLAIN} clause. The PostgreSQL parser does
* not support it.
diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/connection/SimpleParser.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/connection/SimpleParser.java
new file mode 100644
index 00000000000..0af86892dde
--- /dev/null
+++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/connection/SimpleParser.java
@@ -0,0 +1,303 @@
+/*
+ * Copyright 2024 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.google.cloud.spanner.connection;
+
+import static com.google.cloud.spanner.connection.AbstractStatementParser.ASTERISK;
+import static com.google.cloud.spanner.connection.AbstractStatementParser.DASH;
+import static com.google.cloud.spanner.connection.AbstractStatementParser.HYPHEN;
+import static com.google.cloud.spanner.connection.AbstractStatementParser.SLASH;
+
+import com.google.cloud.spanner.Dialect;
+import com.google.common.annotations.VisibleForTesting;
+import com.google.common.base.Preconditions;
+import java.util.Objects;
+
+/** A very simple token-based parser for extracting relevant information from SQL strings. */
+class SimpleParser {
+ /**
+ * An immutable result from a parse action indicating whether the parse action was successful, and
+ * if so, what the value was.
+ */
+ static class Result {
+ static final Result NOT_FOUND = new Result(null);
+
+ static Result found(String value) {
+ return new Result(Preconditions.checkNotNull(value));
+ }
+
+ private final String value;
+
+ private Result(String value) {
+ this.value = value;
+ }
+
+ @Override
+ public int hashCode() {
+ return Objects.hashCode(this.value);
+ }
+
+ @Override
+ public boolean equals(Object o) {
+ if (!(o instanceof Result)) {
+ return false;
+ }
+ return Objects.equals(this.value, ((Result) o).value);
+ }
+
+ @Override
+ public String toString() {
+ if (isValid()) {
+ return this.value;
+ }
+ return "NOT FOUND";
+ }
+
+ boolean isValid() {
+ return this.value != null;
+ }
+
+ String getValue() {
+ return this.value;
+ }
+ }
+
+ // TODO: Replace this with a direct reference to the dialect, and move the isXYZSupported methods
+ // from the AbstractStatementParser class to the Dialect class.
+ private final AbstractStatementParser statementParser;
+
+ private final String sql;
+
+ private final boolean treatHintCommentsAsTokens;
+
+ private int pos;
+
+ /** Constructs a simple parser for the given SQL string and dialect. */
+ SimpleParser(Dialect dialect, String sql) {
+ this(dialect, sql, 0, /* treatHintCommentsAsTokens = */ false);
+ }
+
+ /**
+ * Constructs a simple parser for the given SQL string and dialect.
+ * treatHintCommentsAsTokens
indicates whether comments that start with '/*@' should be
+ * treated as tokens or not. This option may only be enabled if the dialect is PostgreSQL.
+ */
+ SimpleParser(Dialect dialect, String sql, int pos, boolean treatHintCommentsAsTokens) {
+ Preconditions.checkArgument(
+ !(treatHintCommentsAsTokens && dialect != Dialect.POSTGRESQL),
+ "treatHintCommentsAsTokens can only be enabled for PostgreSQL");
+ this.sql = sql;
+ this.pos = pos;
+ this.statementParser = AbstractStatementParser.getInstance(dialect);
+ this.treatHintCommentsAsTokens = treatHintCommentsAsTokens;
+ }
+
+ Dialect getDialect() {
+ return this.statementParser.getDialect();
+ }
+
+ String getSql() {
+ return this.sql;
+ }
+
+ int getPos() {
+ return this.pos;
+ }
+
+ /** Returns true if this parser has more tokens. Advances the position to the first next token. */
+ boolean hasMoreTokens() {
+ skipWhitespaces();
+ return pos < sql.length();
+ }
+
+ /**
+ * Eats and returns the identifier at the current position. This implementation does not support
+ * quoted identifiers.
+ */
+ Result eatIdentifier() {
+ // TODO: Implement support for quoted identifiers.
+ // TODO: Implement support for identifiers with multiple parts (e.g. my_schema.my_table).
+ if (!hasMoreTokens()) {
+ return Result.NOT_FOUND;
+ }
+ if (!isValidIdentifierFirstChar(sql.charAt(pos))) {
+ return Result.NOT_FOUND;
+ }
+ int startPos = pos;
+ while (pos < sql.length() && isValidIdentifierChar(sql.charAt(pos))) {
+ pos++;
+ }
+ return Result.found(sql.substring(startPos, pos));
+ }
+
+ /**
+ * Eats a single-quoted string. This implementation currently does not support escape sequences.
+ */
+ Result eatSingleQuotedString() {
+ if (!eatToken('\'')) {
+ return Result.NOT_FOUND;
+ }
+ int startPos = pos;
+ while (pos < sql.length() && sql.charAt(pos) != '\'') {
+ if (sql.charAt(pos) == '\n') {
+ return Result.NOT_FOUND;
+ }
+ pos++;
+ }
+ if (pos == sql.length()) {
+ return Result.NOT_FOUND;
+ }
+ return Result.found(sql.substring(startPos, pos++));
+ }
+
+ boolean peekTokens(char... tokens) {
+ return internalEatTokens(/* updatePos = */ false, tokens);
+ }
+
+ /**
+ * Returns true if the next tokens in the SQL string are equal to the given tokens, and advances
+ * the position of the parser to after the tokens. The position is not changed if the next tokens
+ * are not equal to the list of tokens.
+ */
+ boolean eatTokens(char... tokens) {
+ return internalEatTokens(/* updatePos = */ true, tokens);
+ }
+
+ /**
+ * Returns true if the next tokens in the SQL string are equal to the given tokens, and advances
+ * the position of the parser to after the tokens if updatePos is true. The position is not
+ * changed if the next tokens are not equal to the list of tokens, or if updatePos is false.
+ */
+ private boolean internalEatTokens(boolean updatePos, char... tokens) {
+ int currentPos = pos;
+ for (char token : tokens) {
+ if (!eatToken(token)) {
+ pos = currentPos;
+ return false;
+ }
+ }
+ if (!updatePos) {
+ pos = currentPos;
+ }
+ return true;
+ }
+
+ /**
+ * Returns true if the next token is equal to the given character, but does not advance the
+ * position of the parser.
+ */
+ boolean peekToken(char token) {
+ int currentPos = pos;
+ boolean res = eatToken(token);
+ pos = currentPos;
+ return res;
+ }
+
+ /**
+ * Returns true and advances the position of the parser if the next token is equal to the given
+ * character.
+ */
+ boolean eatToken(char token) {
+ skipWhitespaces();
+ if (pos < sql.length() && sql.charAt(pos) == token) {
+ pos++;
+ return true;
+ }
+ return false;
+ }
+
+ /**
+ * Returns true if the given character is valid as the first character of an identifier. That
+ * means that it can be used as the first character of an unquoted identifier.
+ */
+ static boolean isValidIdentifierFirstChar(char c) {
+ return Character.isLetter(c) || c == '_';
+ }
+
+ /**
+ * Returns true if the given character is a valid identifier character. That means that it can be
+ * used in an unquoted identifiers.
+ */
+ static boolean isValidIdentifierChar(char c) {
+ return isValidIdentifierFirstChar(c) || Character.isDigit(c) || c == '$';
+ }
+
+ /**
+ * Skips all whitespaces, including comments, from the current position and advances the parser to
+ * the next actual token.
+ */
+ @VisibleForTesting
+ void skipWhitespaces() {
+ while (pos < sql.length()) {
+ if (sql.charAt(pos) == HYPHEN && sql.length() > (pos + 1) && sql.charAt(pos + 1) == HYPHEN) {
+ skipSingleLineComment(/* prefixLength = */ 2);
+ } else if (statementParser.supportsHashSingleLineComments() && sql.charAt(pos) == DASH) {
+ skipSingleLineComment(/* prefixLength = */ 1);
+ } else if (sql.charAt(pos) == SLASH
+ && sql.length() > (pos + 1)
+ && sql.charAt(pos + 1) == ASTERISK) {
+ if (treatHintCommentsAsTokens && sql.length() > (pos + 2) && sql.charAt(pos + 2) == '@') {
+ break;
+ }
+ skipMultiLineComment();
+ } else if (Character.isWhitespace(sql.charAt(pos))) {
+ pos++;
+ } else {
+ break;
+ }
+ }
+ }
+
+ /**
+ * Skips through a single-line comment from the current position. The single-line comment is
+ * started by a prefix with the given length (e.g. either '#' or '--').
+ */
+ @VisibleForTesting
+ boolean skipSingleLineComment(int prefixLength) {
+ int endIndex = sql.indexOf('\n', pos + prefixLength);
+ if (endIndex == -1) {
+ pos = sql.length();
+ return true;
+ }
+ pos = endIndex + 1;
+ return true;
+ }
+
+ /** Skips through a multi-line comment from the current position. */
+ @VisibleForTesting
+ boolean skipMultiLineComment() {
+ int level = 1;
+ pos += 2;
+ while (pos < sql.length()) {
+ if (statementParser.supportsNestedComments()
+ && sql.charAt(pos) == SLASH
+ && sql.length() > (pos + 1)
+ && sql.charAt(pos + 1) == ASTERISK) {
+ level++;
+ }
+ if (sql.charAt(pos) == ASTERISK && sql.length() > (pos + 1) && sql.charAt(pos + 1) == SLASH) {
+ level--;
+ if (level == 0) {
+ pos += 2;
+ return true;
+ }
+ }
+ pos++;
+ }
+ pos = sql.length();
+ return false;
+ }
+}
diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/connection/SpannerStatementParser.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/connection/SpannerStatementParser.java
index 892672ad0df..fdd10bbf5ae 100644
--- a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/connection/SpannerStatementParser.java
+++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/connection/SpannerStatementParser.java
@@ -41,6 +41,11 @@ public SpannerStatementParser() throws CompileException {
ClientSideStatements.getInstance(Dialect.GOOGLE_STANDARD_SQL).getCompiledStatements()));
}
+ @Override
+ Dialect getDialect() {
+ return Dialect.GOOGLE_STANDARD_SQL;
+ }
+
/**
* Indicates whether the parser supports the {@code EXPLAIN} clause. The Spanner parser does
* support it.
diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/connection/StatementHintParser.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/connection/StatementHintParser.java
new file mode 100644
index 00000000000..d6d4a7fa48c
--- /dev/null
+++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/connection/StatementHintParser.java
@@ -0,0 +1,211 @@
+/*
+ * Copyright 2024 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.google.cloud.spanner.connection;
+
+import com.google.cloud.Tuple;
+import com.google.cloud.spanner.Dialect;
+import com.google.cloud.spanner.ErrorCode;
+import com.google.cloud.spanner.Options;
+import com.google.cloud.spanner.Options.ReadQueryUpdateTransactionOption;
+import com.google.cloud.spanner.Options.RpcPriority;
+import com.google.cloud.spanner.SpannerExceptionFactory;
+import com.google.cloud.spanner.connection.SimpleParser.Result;
+import com.google.common.base.Preconditions;
+import com.google.common.collect.ImmutableMap;
+import com.google.common.collect.ImmutableSet;
+import com.google.spanner.v1.RequestOptions.Priority;
+import java.util.Locale;
+import java.util.Map;
+import java.util.Map.Entry;
+
+/** A simple parser for extracting statement hints from SQL strings. */
+class StatementHintParser {
+ private static final char[] GOOGLE_SQL_START_HINT_TOKENS = new char[] {'@', '{'};
+ private static final char[] POSTGRESQL_START_HINT_TOKENS = new char[] {'/', '*', '@'};
+ private static final char[] GOOGLE_SQL_END_HINT_TOKENS = new char[] {'}'};
+ private static final char[] POSTGRESQL_END_HINT_TOKENS = new char[] {'*', '/'};
+ private static final String STATEMENT_TAG_HINT_NAME = "STATEMENT_TAG";
+ private static final String RPC_PRIORITY_HINT_NAME = "RPC_PRIORITY";
+ private static final ImmutableSet CLIENT_SIDE_STATEMENT_HINT_NAMES =
+ ImmutableSet.of(STATEMENT_TAG_HINT_NAME, RPC_PRIORITY_HINT_NAME);
+
+ static final Map NO_HINTS = ImmutableMap.of();
+
+ private final boolean hasStatementHints;
+
+ private final Map hints;
+
+ private final String sqlWithoutClientSideHints;
+
+ StatementHintParser(Dialect dialect, String sql) {
+ this(CLIENT_SIDE_STATEMENT_HINT_NAMES, dialect, sql);
+ }
+
+ StatementHintParser(
+ ImmutableSet clientSideStatementHintNames, Dialect dialect, String sql) {
+ SimpleParser parser =
+ new SimpleParser(
+ dialect,
+ sql,
+ /* pos = */ 0,
+ /* treatHintCommentsAsTokens = */ dialect == Dialect.POSTGRESQL);
+ this.hasStatementHints = parser.peekTokens(getStartHintTokens(dialect));
+ if (this.hasStatementHints) {
+ Tuple> hints = extract(parser, clientSideStatementHintNames);
+ this.sqlWithoutClientSideHints = hints.x();
+ this.hints = hints.y();
+ } else {
+ this.sqlWithoutClientSideHints = sql;
+ this.hints = NO_HINTS;
+ }
+ }
+
+ private static char[] getStartHintTokens(Dialect dialect) {
+ switch (dialect) {
+ case POSTGRESQL:
+ return POSTGRESQL_START_HINT_TOKENS;
+ case GOOGLE_STANDARD_SQL:
+ default:
+ return GOOGLE_SQL_START_HINT_TOKENS;
+ }
+ }
+
+ private static char[] getEndHintTokens(Dialect dialect) {
+ switch (dialect) {
+ case POSTGRESQL:
+ return POSTGRESQL_END_HINT_TOKENS;
+ case GOOGLE_STANDARD_SQL:
+ default:
+ return GOOGLE_SQL_END_HINT_TOKENS;
+ }
+ }
+
+ /**
+ * Extracts any query/update options from client-side hints in the given statement. Currently,
+ * this method supports following client-side hints:
+ *
+ *
+ * - STATEMENT_TAG
+ *
- RPC_PRIORITY
+ *
+ */
+ static ReadQueryUpdateTransactionOption[] convertHintsToOptions(Map hints) {
+ ReadQueryUpdateTransactionOption[] result = new ReadQueryUpdateTransactionOption[hints.size()];
+ int index = 0;
+ for (Entry hint : hints.entrySet()) {
+ result[index++] = convertHintToOption(hint.getKey(), hint.getValue());
+ }
+ return result;
+ }
+
+ private static ReadQueryUpdateTransactionOption convertHintToOption(String hint, String value) {
+ Preconditions.checkNotNull(value);
+ switch (Preconditions.checkNotNull(hint).toUpperCase(Locale.ENGLISH)) {
+ case STATEMENT_TAG_HINT_NAME:
+ return Options.tag(value);
+ case RPC_PRIORITY_HINT_NAME:
+ try {
+ Priority priority = Priority.valueOf(value);
+ return Options.priority(RpcPriority.fromProto(priority));
+ } catch (IllegalArgumentException illegalArgumentException) {
+ throw SpannerExceptionFactory.newSpannerException(
+ ErrorCode.INVALID_ARGUMENT,
+ "Invalid RPC priority value: " + value,
+ illegalArgumentException);
+ }
+ default:
+ throw SpannerExceptionFactory.newSpannerException(
+ ErrorCode.INVALID_ARGUMENT, "Invalid hint name: " + hint);
+ }
+ }
+
+ boolean hasStatementHints() {
+ return this.hasStatementHints;
+ }
+
+ String getSqlWithoutClientSideHints() {
+ return this.sqlWithoutClientSideHints;
+ }
+
+ Map getClientSideStatementHints() {
+ return this.hints;
+ }
+
+ private static Tuple> extract(
+ SimpleParser parser, ImmutableSet clientSideStatementHintNames) {
+ String updatedSql = parser.getSql();
+ int posBeforeHintToken = parser.getPos();
+ int removedHintsLength = 0;
+ boolean allClientSideHints = true;
+ // This method is only called if the parser has hints, so it is safe to ignore this result.
+ parser.eatTokens(getStartHintTokens(parser.getDialect()));
+ ImmutableMap.Builder builder = ImmutableMap.builder();
+ while (parser.hasMoreTokens()) {
+ int posBeforeHint = parser.getPos();
+ boolean foundClientSideHint = false;
+ Result hintName = parser.eatIdentifier();
+ if (!hintName.isValid()) {
+ return Tuple.of(parser.getSql(), NO_HINTS);
+ }
+ if (!parser.eatToken('=')) {
+ return Tuple.of(parser.getSql(), NO_HINTS);
+ }
+ Result hintValue = eatHintLiteral(parser);
+ if (!hintValue.isValid()) {
+ return Tuple.of(parser.getSql(), NO_HINTS);
+ }
+ if (clientSideStatementHintNames.contains(hintName.getValue().toUpperCase(Locale.ENGLISH))) {
+ builder.put(hintName.getValue(), hintValue.getValue());
+ foundClientSideHint = true;
+ } else {
+ allClientSideHints = false;
+ }
+ boolean endOfHints = parser.peekTokens(getEndHintTokens(parser.getDialect()));
+ if (!endOfHints && !parser.eatToken(',')) {
+ return Tuple.of(parser.getSql(), NO_HINTS);
+ }
+ if (foundClientSideHint) {
+ // Remove the client-side hint from the SQL string that is sent to Spanner.
+ updatedSql =
+ updatedSql.substring(0, posBeforeHint - removedHintsLength)
+ + parser.getSql().substring(parser.getPos());
+ removedHintsLength += parser.getPos() - posBeforeHint;
+ }
+ if (endOfHints) {
+ break;
+ }
+ }
+ if (!parser.eatTokens(getEndHintTokens(parser.getDialect()))) {
+ return Tuple.of(parser.getSql(), NO_HINTS);
+ }
+ if (allClientSideHints) {
+ // Only client-side hints found. Remove the entire hint block.
+ updatedSql =
+ parser.getSql().substring(0, posBeforeHintToken)
+ + parser.getSql().substring(parser.getPos());
+ }
+ return Tuple.of(updatedSql, builder.build());
+ }
+
+ /** Eats a hint literal. This is a literal that could be a quoted string, or an identifier. */
+ private static Result eatHintLiteral(SimpleParser parser) {
+ if (parser.peekToken('\'')) {
+ return parser.eatSingleQuotedString();
+ }
+ return parser.eatIdentifier();
+ }
+}
diff --git a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/connection/SimpleParserTest.java b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/connection/SimpleParserTest.java
new file mode 100644
index 00000000000..2f51e7d0443
--- /dev/null
+++ b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/connection/SimpleParserTest.java
@@ -0,0 +1,224 @@
+/*
+ * Copyright 2024 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.google.cloud.spanner.connection;
+
+import static com.google.cloud.spanner.connection.SimpleParser.Result.NOT_FOUND;
+import static com.google.cloud.spanner.connection.SimpleParser.Result.found;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertNotEquals;
+import static org.junit.Assert.assertNull;
+import static org.junit.Assert.assertTrue;
+
+import com.google.cloud.spanner.Dialect;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+import org.junit.runners.Parameterized.Parameter;
+import org.junit.runners.Parameterized.Parameters;
+
+@RunWith(Parameterized.class)
+public class SimpleParserTest {
+
+ @Parameters(name = "dialect = {0}")
+ public static Object[] data() {
+ return Dialect.values();
+ }
+
+ @Parameter public Dialect dialect;
+
+ SimpleParser parserFor(String sql) {
+ return new SimpleParser(dialect, sql);
+ }
+
+ @Test
+ public void testResultHashCode() {
+ assertEquals(0, NOT_FOUND.hashCode());
+ assertEquals(found("foo").hashCode(), found("foo").hashCode());
+ assertNotEquals(found("foo").hashCode(), found("bar").hashCode());
+ assertNotEquals(NOT_FOUND.hashCode(), found("bar").hashCode());
+ }
+
+ @Test
+ public void testResultEquals() {
+ assertEquals(found("foo"), found("foo"));
+ assertNotEquals(found("foo"), found("bar"));
+ assertNotEquals(NOT_FOUND, found("bar"));
+ assertNotEquals(found("foo"), new Object());
+ assertNotEquals(NOT_FOUND, new Object());
+ }
+
+ @Test
+ public void testResultToString() {
+ assertEquals("foo", found("foo").toString());
+ assertEquals("NOT FOUND", NOT_FOUND.toString());
+ }
+
+ @Test
+ public void testResultGetValue() {
+ assertEquals("foo", found("foo").getValue());
+ assertNull(NOT_FOUND.getValue());
+ }
+
+ @Test
+ public void testEatToken() {
+ assertTrue(parserFor("(foo").eatToken('('));
+ assertTrue(parserFor("(").eatToken('('));
+ assertTrue(parserFor("( ").eatToken('('));
+ assertTrue(parserFor("\t( foo").eatToken('('));
+
+ assertFalse(parserFor("foo(").eatToken('('));
+ assertFalse(parserFor("").eatToken('('));
+ }
+
+ @Test
+ public void testEatTokenAdvancesPosition() {
+ SimpleParser parser = parserFor("@{test=value}");
+ assertEquals(0, parser.getPos());
+ assertTrue(parser.eatToken('@'));
+ assertEquals(1, parser.getPos());
+
+ assertFalse(parser.eatToken('('));
+ assertEquals(1, parser.getPos());
+
+ assertTrue(parser.eatToken('{'));
+ assertEquals(2, parser.getPos());
+ }
+
+ @Test
+ public void testEatTokensAdvancesPosition() {
+ SimpleParser parser = parserFor("@{test=value}");
+ assertEquals(0, parser.getPos());
+ assertTrue(parser.eatTokens('@', '{'));
+ assertEquals(2, parser.getPos());
+
+ assertFalse(parser.eatTokens('@', '{'));
+ assertEquals(2, parser.getPos());
+
+ parser = parserFor("@ /* comment */ { test=value}");
+ assertEquals(0, parser.getPos());
+ assertTrue(parser.eatTokens('@', '{'));
+ assertEquals("@ /* comment */ {".length(), parser.getPos());
+ }
+
+ @Test
+ public void testPeekTokenKeepsPosition() {
+ SimpleParser parser = parserFor("@{test=value}");
+ assertEquals(0, parser.getPos());
+ assertTrue(parser.peekToken('@'));
+ assertEquals(0, parser.getPos());
+
+ assertFalse(parser.peekToken('{'));
+ assertEquals(0, parser.getPos());
+ }
+
+ @Test
+ public void testPeekTokensKeepsPosition() {
+ SimpleParser parser = parserFor("@{test=value}");
+ assertEquals(0, parser.getPos());
+ assertTrue(parser.peekTokens('@', '{'));
+ assertEquals(0, parser.getPos());
+ }
+
+ @Test
+ public void testEatIdentifier() {
+ assertEquals(found("foo"), parserFor("foo").eatIdentifier());
+ assertEquals(found("foo"), parserFor("foo(id)").eatIdentifier());
+ assertEquals(found("foo"), parserFor("foo bar").eatIdentifier());
+
+ assertEquals(found("foo"), parserFor(" foo bar").eatIdentifier());
+ assertEquals(found("foo"), parserFor("\tfoo").eatIdentifier());
+ assertEquals(found("bar"), parserFor("\n bar").eatIdentifier());
+ assertEquals(found("foo"), parserFor(" foo").eatIdentifier());
+ assertEquals(found("foo"), parserFor("foo\"bar\"").eatIdentifier());
+ assertEquals(found("foo"), parserFor("foo.bar").eatIdentifier());
+
+ assertEquals(found("foo"), parserFor("foo) bar").eatIdentifier());
+ assertEquals(found("foo"), parserFor("foo- bar").eatIdentifier());
+ assertEquals(found("foo"), parserFor("foo/ bar").eatIdentifier());
+ assertEquals(found("foo$"), parserFor("foo$ bar").eatIdentifier());
+ assertEquals(found("f$oo"), parserFor("f$oo bar").eatIdentifier());
+ assertEquals(found("_foo"), parserFor("_foo bar").eatIdentifier());
+ assertEquals(found("øfoo"), parserFor("øfoo bar").eatIdentifier());
+
+ assertEquals(NOT_FOUND, parserFor("\"foo").eatIdentifier());
+ assertEquals(NOT_FOUND, parserFor("\\foo").eatIdentifier());
+ assertEquals(NOT_FOUND, parserFor("1foo").eatIdentifier());
+ assertEquals(NOT_FOUND, parserFor("-foo").eatIdentifier());
+ assertEquals(NOT_FOUND, parserFor("$foo").eatIdentifier());
+ assertEquals(NOT_FOUND, parserFor("").eatIdentifier());
+ assertEquals(NOT_FOUND, parserFor(" ").eatIdentifier());
+ assertEquals(NOT_FOUND, parserFor("\n").eatIdentifier());
+ assertEquals(NOT_FOUND, parserFor("/* comment */").eatIdentifier());
+ assertEquals(NOT_FOUND, parserFor("-- comment").eatIdentifier());
+ assertEquals(NOT_FOUND, parserFor("/* comment").eatIdentifier());
+
+ String nestedCommentFollowedByIdentifier =
+ "/* comment /* nested comment */ "
+ + "still a comment if nested comments are supported, "
+ + "and otherwise the start of an identifier. */ test";
+ if (AbstractStatementParser.getInstance(dialect).supportsNestedComments()) {
+ assertEquals(found("test"), parserFor(nestedCommentFollowedByIdentifier).eatIdentifier());
+ } else {
+ // The parser does not look ahead if the rest of the SQL string is malformed. It just reads
+ // from the current position.
+ assertEquals(found("still"), parserFor(nestedCommentFollowedByIdentifier).eatIdentifier());
+ }
+
+ if (AbstractStatementParser.getInstance(dialect).supportsHashSingleLineComments()) {
+ assertEquals(found("test"), parserFor("# comment\ntest").eatIdentifier());
+ } else {
+ // '#' is not a valid start of an identifier.
+ assertEquals(NOT_FOUND, parserFor("# not a comment\ntest").eatIdentifier());
+ }
+ }
+
+ @Test
+ public void testEatSingleQuotedString() {
+ assertEquals(found("test"), parserFor("'test'").eatSingleQuotedString());
+ assertEquals(found("test"), parserFor(" 'test' ").eatSingleQuotedString());
+ assertEquals(found("test"), parserFor("\n'test'").eatSingleQuotedString());
+ assertEquals(found("test"), parserFor("\t'test'").eatSingleQuotedString());
+ assertEquals(found("test test"), parserFor(" 'test test' ").eatSingleQuotedString());
+ assertEquals(found("test\t"), parserFor("'test\t'").eatSingleQuotedString());
+ assertEquals(
+ found("test"), parserFor("/* comment */'test'/*comment*/").eatSingleQuotedString());
+ assertEquals(found("test"), parserFor("-- comment\n'test'--comment\n").eatSingleQuotedString());
+ assertEquals(
+ found("test /* not a comment */"),
+ parserFor("'test /* not a comment */'").eatSingleQuotedString());
+
+ assertEquals(NOT_FOUND, parserFor("test").eatSingleQuotedString());
+ assertEquals(NOT_FOUND, parserFor("'test").eatSingleQuotedString());
+ assertEquals(NOT_FOUND, parserFor("test'").eatSingleQuotedString());
+ assertEquals(NOT_FOUND, parserFor("\"test\"").eatSingleQuotedString());
+ assertEquals(NOT_FOUND, parserFor("'test\n'").eatSingleQuotedString());
+ assertEquals(NOT_FOUND, parserFor("'\ntest'").eatSingleQuotedString());
+ assertEquals(NOT_FOUND, parserFor("'te\nst'").eatSingleQuotedString());
+ }
+
+ @Test
+ public void testEatSingleQuotedStringAdvancesPosition() {
+ SimpleParser parser = parserFor("'test 1' 'test 2' ");
+ assertEquals(found("test 1"), parser.eatSingleQuotedString());
+ assertEquals("'test 1'".length(), parser.getPos());
+ assertEquals(found("test 2"), parser.eatSingleQuotedString());
+ assertEquals("'test 1' 'test 2'".length(), parser.getPos());
+ assertEquals(NOT_FOUND, parser.eatSingleQuotedString());
+ assertEquals(parser.getSql().length(), parser.getPos());
+ }
+}
diff --git a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/connection/StatementHintParserTest.java b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/connection/StatementHintParserTest.java
new file mode 100644
index 00000000000..d1f276849f0
--- /dev/null
+++ b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/connection/StatementHintParserTest.java
@@ -0,0 +1,210 @@
+/*
+ * Copyright 2024 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.google.cloud.spanner.connection;
+
+import static com.google.cloud.spanner.connection.StatementHintParser.NO_HINTS;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertSame;
+import static org.junit.Assert.assertTrue;
+
+import com.google.cloud.spanner.Dialect;
+import com.google.common.collect.ImmutableMap;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+import org.junit.runners.Parameterized.Parameter;
+import org.junit.runners.Parameterized.Parameters;
+
+@RunWith(Parameterized.class)
+public class StatementHintParserTest {
+
+ @Parameters(name = "dialect = {0}")
+ public static Object[] data() {
+ return Dialect.values();
+ }
+
+ @Parameter public Dialect dialect;
+
+ StatementHintParser parserFor(String sql) {
+ return new StatementHintParser(dialect, sql);
+ }
+
+ String getStartHint() {
+ return dialect == Dialect.POSTGRESQL ? "/*@" : "@{";
+ }
+
+ String getEndHint() {
+ return dialect == Dialect.POSTGRESQL ? "*/" : "}";
+ }
+
+ String encloseInHint(String sql) {
+ return getStartHint() + sql + getEndHint();
+ }
+
+ @Test
+ public void testNoHints() {
+ assertFalse(parserFor("select foo from bar").hasStatementHints());
+ assertFalse(parserFor("/* comment */ select foo from bar").hasStatementHints());
+ assertFalse(parserFor("select foo from bar").hasStatementHints());
+ assertFalse(parserFor("select foo from bar").hasStatementHints());
+ }
+
+ @Test
+ public void testExtractHints() {
+ StatementHintParser parser;
+
+ parser = parserFor(encloseInHint("statement_tag=tag1") + " select 1");
+ assertTrue(parser.hasStatementHints());
+ assertEquals(ImmutableMap.of("statement_tag", "tag1"), parser.getClientSideStatementHints());
+ assertEquals(" select 1", parser.getSqlWithoutClientSideHints());
+
+ parser = parserFor(encloseInHint("statement_tag=tag1, other_hint=value") + " select 1");
+ assertTrue(parser.hasStatementHints());
+ assertEquals(ImmutableMap.of("statement_tag", "tag1"), parser.getClientSideStatementHints());
+ assertEquals(
+ encloseInHint(" other_hint=value") + " select 1", parser.getSqlWithoutClientSideHints());
+
+ parser = parserFor(encloseInHint("other_hint=value") + " select 1");
+ assertTrue(parser.hasStatementHints());
+ assertEquals(NO_HINTS, parser.getClientSideStatementHints());
+ assertEquals(
+ encloseInHint("other_hint=value") + " select 1", parser.getSqlWithoutClientSideHints());
+
+ parser = parserFor(encloseInHint("statement_tag=tag1, rpc_priority=high") + " select 1");
+ assertTrue(parser.hasStatementHints());
+ assertEquals(
+ ImmutableMap.of("statement_tag", "tag1", "rpc_priority", "high"),
+ parser.getClientSideStatementHints());
+ assertEquals(" select 1", parser.getSqlWithoutClientSideHints());
+
+ parser = parserFor(encloseInHint("rpc_priority=medium, statement_tag='value 2'") + " select 1");
+ assertTrue(parser.hasStatementHints());
+ assertEquals(
+ ImmutableMap.of("rpc_priority", "medium", "statement_tag", "value 2"),
+ parser.getClientSideStatementHints());
+ assertEquals(" select 1", parser.getSqlWithoutClientSideHints());
+
+ parser =
+ parserFor(
+ "/* comment */ "
+ + encloseInHint(
+ "/*comment*/statement_tag--comment\n"
+ + "=--comment\nvalue1\n,rpc_priority=Low/*comment*/")
+ + " /* yet another comment */ select 1");
+ assertTrue(parser.hasStatementHints());
+ assertEquals(
+ ImmutableMap.of("statement_tag", "value1", "rpc_priority", "Low"),
+ parser.getClientSideStatementHints());
+ assertEquals(" /* yet another comment */ select 1", parser.getSqlWithoutClientSideHints());
+
+ parser =
+ parserFor(
+ "/* comment */ "
+ + encloseInHint(
+ "/*comment*/statement_tag--comment\n"
+ + "=--comment\nvalue1\n,"
+ + "/* other hint comment */ other_hint='some value',\n"
+ + "rpc_priority=Low/*comment*/")
+ + " /* yet another comment */ select 1");
+ assertTrue(parser.hasStatementHints());
+ assertEquals(
+ ImmutableMap.of("statement_tag", "value1", "rpc_priority", "Low"),
+ parser.getClientSideStatementHints());
+ assertEquals(
+ "/* comment */ "
+ + encloseInHint(
+ "/*comment*//* other hint comment */ other_hint='some value',\n" + "/*comment*/")
+ + " /* yet another comment */ select 1",
+ parser.getSqlWithoutClientSideHints());
+
+ parser =
+ parserFor(
+ encloseInHint(
+ "statement_tag=tag1,\n"
+ + "other_hint1='some value',\n"
+ + "rpc_priority=low,\n"
+ + "other_hint2=value")
+ + "\nselect 1");
+ assertTrue(parser.hasStatementHints());
+ assertEquals(
+ ImmutableMap.of("statement_tag", "tag1", "rpc_priority", "low"),
+ parser.getClientSideStatementHints());
+ assertEquals(
+ encloseInHint("\nother_hint1='some value',\n" + "\n" + "other_hint2=value") + "\nselect 1",
+ parser.getSqlWithoutClientSideHints());
+
+ parser =
+ parserFor(
+ encloseInHint(
+ "hint1=value1,\n"
+ + "other_hint1='some value',\n"
+ + "rpc_priority=low,\n"
+ + "other_hint2=value")
+ + "\nselect 1");
+ assertTrue(parser.hasStatementHints());
+ assertEquals(ImmutableMap.of("rpc_priority", "low"), parser.getClientSideStatementHints());
+ assertEquals(
+ encloseInHint(
+ "hint1=value1,\n" + "other_hint1='some value',\n" + "\n" + "other_hint2=value")
+ + "\nselect 1",
+ parser.getSqlWithoutClientSideHints());
+
+ parser =
+ parserFor(
+ encloseInHint(
+ "hint1=value1,\n"
+ + "hint2=value2,\n"
+ + "rpc_priority=low,\n"
+ + "statement_tag=tag")
+ + "\nselect 1");
+ assertTrue(parser.hasStatementHints());
+ assertEquals(
+ ImmutableMap.of("rpc_priority", "low", "statement_tag", "tag"),
+ parser.getClientSideStatementHints());
+ assertEquals(
+ encloseInHint("hint1=value1,\nhint2=value2,\n\n") + "\nselect 1",
+ parser.getSqlWithoutClientSideHints());
+ }
+
+ @Test
+ public void testExtractInvalidHints() {
+ assertInvalidHints("@{statement_tag=value value}");
+ assertInvalidHints("@statement_tag=value");
+ assertInvalidHints("{statement_tag=value}");
+ assertInvalidHints("@{statement_tag=value");
+ assertInvalidHints("@{statement_tag=value,");
+ assertInvalidHints("@{statement_tag=value,}");
+ assertInvalidHints("@statement_tag=value}");
+ assertInvalidHints("@{statement_tag=}");
+ assertInvalidHints("@{=value}");
+ assertInvalidHints("@{}");
+ assertInvalidHints("@{statement_tag=value,}");
+ assertInvalidHints("@{statement_tag=value1,hint2=value2,}");
+ assertInvalidHints("@{@statement_tag=value1}");
+ assertInvalidHints("@{statement_tag=@value1}");
+ assertInvalidHints("@{statement_tag value1}");
+ assertInvalidHints("@{statement_tag='value1}");
+ assertInvalidHints("@{statement_tag=value1'}");
+ }
+
+ private void assertInvalidHints(String sql) {
+ StatementHintParser parser = parserFor(sql);
+ assertEquals(NO_HINTS, parser.getClientSideStatementHints());
+ assertSame(sql, parser.getSqlWithoutClientSideHints());
+ }
+}
diff --git a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/connection/StatementParserTest.java b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/connection/StatementParserTest.java
index 3739aa11064..57758886738 100644
--- a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/connection/StatementParserTest.java
+++ b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/connection/StatementParserTest.java
@@ -1575,8 +1575,9 @@ public void testPostgreSQLReturningClause() {
parser.parse(Statement.of("insert into t1 select 10.returning*")).hasReturningClause());
}
- int skipSingleLineComment(String sql, int startIndex) {
- return AbstractStatementParser.skipSingleLineComment(sql, startIndex, null);
+ int skipSingleLineComment(String sql, int prefixLength, int startIndex) {
+ return AbstractStatementParser.skipSingleLineComment(
+ dialect, sql, prefixLength, startIndex, null);
}
int skipMultiLineComment(String sql, int startIndex) {
@@ -1606,12 +1607,12 @@ public void testConcatenatedLiterals() {
public void testSkipSingleLineComment() {
assumeTrue(dialect == Dialect.POSTGRESQL);
- assertEquals(7, skipSingleLineComment("-- foo\n", 0));
- assertEquals(7, skipSingleLineComment("-- foo\nbar", 0));
- assertEquals(6, skipSingleLineComment("-- foo", 0));
- assertEquals(11, skipSingleLineComment("bar -- foo\n", 4));
- assertEquals(11, skipSingleLineComment("bar -- foo\nbar", 4));
- assertEquals(10, skipSingleLineComment("bar -- foo", 4));
+ assertEquals(7, skipSingleLineComment("-- foo\n", 2, 0));
+ assertEquals(7, skipSingleLineComment("-- foo\nbar", 2, 0));
+ assertEquals(6, skipSingleLineComment("-- foo", 2, 0));
+ assertEquals(11, skipSingleLineComment("bar -- foo\n", 2, 4));
+ assertEquals(11, skipSingleLineComment("bar -- foo\nbar", 2, 4));
+ assertEquals(10, skipSingleLineComment("bar -- foo", 2, 4));
}
@Test
diff --git a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/connection/TaggingTest.java b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/connection/TaggingTest.java
index ada17bca219..80210cb7094 100644
--- a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/connection/TaggingTest.java
+++ b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/connection/TaggingTest.java
@@ -18,25 +18,49 @@
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertThrows;
import static org.junit.Assert.assertTrue;
-import static org.junit.Assert.fail;
+import com.google.cloud.spanner.Dialect;
import com.google.cloud.spanner.ErrorCode;
+import com.google.cloud.spanner.MockSpannerServiceImpl;
import com.google.cloud.spanner.ResultSet;
import com.google.cloud.spanner.SpannerException;
import com.google.cloud.spanner.Statement;
import com.google.spanner.v1.CommitRequest;
import com.google.spanner.v1.ExecuteBatchDmlRequest;
import com.google.spanner.v1.ExecuteSqlRequest;
-import java.util.Arrays;
+import java.util.Collections;
import org.junit.After;
+import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
-import org.junit.runners.JUnit4;
+import org.junit.runners.Parameterized;
+import org.junit.runners.Parameterized.Parameter;
+import org.junit.runners.Parameterized.Parameters;
-@RunWith(JUnit4.class)
+@RunWith(Parameterized.class)
public class TaggingTest extends AbstractMockServerTest {
+ @Parameters(name = "dialect = {0}")
+ public static Object[] data() {
+ return Dialect.values();
+ }
+
+ @Parameter public Dialect dialect;
+
+ private Dialect currentDialect;
+
+ @Before
+ public void setupDialect() {
+ if (currentDialect != dialect) {
+ mockSpanner.putStatementResult(
+ MockSpannerServiceImpl.StatementResult.detectDialectResult(dialect));
+ SpannerPool.closeSpannerPool();
+ currentDialect = dialect;
+ }
+ }
+
@After
public void clearRequests() {
mockSpanner.clearRequests();
@@ -46,12 +70,8 @@ public void clearRequests() {
public void testStatementTagNotAllowedForCommit() {
try (Connection connection = createConnection()) {
connection.setStatementTag("tag-1");
- try {
- connection.commit();
- fail("missing expected exception");
- } catch (SpannerException e) {
- assertEquals(ErrorCode.FAILED_PRECONDITION, e.getErrorCode());
- }
+ SpannerException exception = assertThrows(SpannerException.class, connection::commit);
+ assertEquals(ErrorCode.FAILED_PRECONDITION, exception.getErrorCode());
}
}
@@ -59,12 +79,8 @@ public void testStatementTagNotAllowedForCommit() {
public void testStatementTagNotAllowedForRollback() {
try (Connection connection = createConnection()) {
connection.setStatementTag("tag-1");
- try {
- connection.rollback();
- fail("missing expected exception");
- } catch (SpannerException e) {
- assertEquals(ErrorCode.FAILED_PRECONDITION, e.getErrorCode());
- }
+ SpannerException exception = assertThrows(SpannerException.class, connection::rollback);
+ assertEquals(ErrorCode.FAILED_PRECONDITION, exception.getErrorCode());
}
}
@@ -74,12 +90,11 @@ public void testStatementTagNotAllowedInsideBatch() {
for (boolean autocommit : new boolean[] {true, false}) {
connection.setAutocommit(autocommit);
connection.startBatchDml();
- try {
- connection.setStatementTag("tag-1");
- fail("missing expected exception");
- } catch (SpannerException e) {
- assertEquals(ErrorCode.FAILED_PRECONDITION, e.getErrorCode());
- }
+
+ SpannerException exception =
+ assertThrows(SpannerException.class, () -> connection.setStatementTag("tag-1"));
+ assertEquals(ErrorCode.FAILED_PRECONDITION, exception.getErrorCode());
+
connection.abortBatch();
}
}
@@ -90,7 +105,8 @@ public void testQuery_NoTags() {
try (Connection connection = createConnection()) {
for (boolean autocommit : new boolean[] {true, false}) {
connection.setAutocommit(autocommit);
- try (ResultSet rs = connection.executeQuery(SELECT_COUNT_STATEMENT)) {}
+ //noinspection EmptyTryBlock
+ try (ResultSet ignore = connection.executeQuery(SELECT_COUNT_STATEMENT)) {}
assertEquals(1, mockSpanner.countRequestsOfType(ExecuteSqlRequest.class));
assertEquals(
"",
@@ -172,7 +188,7 @@ public void testBatchUpdate_NoTags() {
try (Connection connection = createConnection()) {
for (boolean autocommit : new boolean[] {true, false}) {
connection.setAutocommit(autocommit);
- connection.executeBatchUpdate(Arrays.asList(INSERT_STATEMENT));
+ connection.executeBatchUpdate(Collections.singletonList(INSERT_STATEMENT));
assertEquals(1, mockSpanner.countRequestsOfType(ExecuteBatchDmlRequest.class));
assertEquals(
@@ -201,7 +217,8 @@ public void testQuery_StatementTag() {
for (boolean autocommit : new boolean[] {true, false}) {
connection.setAutocommit(autocommit);
connection.setStatementTag("tag-1");
- try (ResultSet rs = connection.executeQuery(SELECT_COUNT_STATEMENT)) {}
+ //noinspection EmptyTryBlock
+ try (ResultSet ignore = connection.executeQuery(SELECT_COUNT_STATEMENT)) {}
assertEquals(1, mockSpanner.countRequestsOfType(ExecuteSqlRequest.class));
assertEquals(
"tag-1",
@@ -221,7 +238,8 @@ public void testQuery_StatementTag() {
mockSpanner.clearRequests();
// The tag should automatically be cleared after a statement.
- try (ResultSet rs = connection.executeQuery(SELECT_COUNT_STATEMENT)) {}
+ //noinspection EmptyTryBlock
+ try (ResultSet ignore = connection.executeQuery(SELECT_COUNT_STATEMENT)) {}
assertEquals(1, mockSpanner.countRequestsOfType(ExecuteSqlRequest.class));
assertEquals(
"",
@@ -346,7 +364,7 @@ public void testBatchUpdate_StatementTag() {
for (boolean autocommit : new boolean[] {true, false}) {
connection.setAutocommit(autocommit);
connection.setStatementTag("tag-3");
- connection.executeBatchUpdate(Arrays.asList(INSERT_STATEMENT));
+ connection.executeBatchUpdate(Collections.singletonList(INSERT_STATEMENT));
assertEquals(1, mockSpanner.countRequestsOfType(ExecuteBatchDmlRequest.class));
assertEquals(
@@ -366,7 +384,7 @@ public void testBatchUpdate_StatementTag() {
mockSpanner.clearRequests();
- connection.executeBatchUpdate(Arrays.asList(INSERT_STATEMENT));
+ connection.executeBatchUpdate(Collections.singletonList(INSERT_STATEMENT));
assertEquals(1, mockSpanner.countRequestsOfType(ExecuteBatchDmlRequest.class));
assertEquals(
@@ -393,7 +411,8 @@ public void testBatchUpdate_StatementTag() {
public void testQuery_TransactionTag() {
try (Connection connection = createConnection()) {
connection.setTransactionTag("tag-1");
- try (ResultSet rs = connection.executeQuery(SELECT_COUNT_STATEMENT)) {}
+ //noinspection EmptyTryBlock
+ try (ResultSet ignore = connection.executeQuery(SELECT_COUNT_STATEMENT)) {}
connection.commit();
assertEquals(1, mockSpanner.countRequestsOfType(ExecuteSqlRequest.class));
@@ -430,7 +449,8 @@ public void testQuery_TransactionTag() {
mockSpanner.clearRequests();
// The tag should automatically be cleared after a statement.
- try (ResultSet rs = connection.executeQuery(SELECT_COUNT_STATEMENT)) {}
+ //noinspection EmptyTryBlock
+ try (ResultSet ignore = connection.executeQuery(SELECT_COUNT_STATEMENT)) {}
connection.commit();
assertEquals(1, mockSpanner.countRequestsOfType(ExecuteSqlRequest.class));
@@ -550,7 +570,7 @@ public void testUpdate_TransactionTag() {
public void testBatchUpdate_TransactionTag() {
try (Connection connection = createConnection()) {
connection.setTransactionTag("tag-3");
- connection.executeBatchUpdate(Arrays.asList(INSERT_STATEMENT));
+ connection.executeBatchUpdate(Collections.singletonList(INSERT_STATEMENT));
connection.commit();
assertEquals(1, mockSpanner.countRequestsOfType(ExecuteBatchDmlRequest.class));
@@ -586,7 +606,7 @@ public void testBatchUpdate_TransactionTag() {
mockSpanner.clearRequests();
- connection.executeBatchUpdate(Arrays.asList(INSERT_STATEMENT));
+ connection.executeBatchUpdate(Collections.singletonList(INSERT_STATEMENT));
connection.commit();
assertEquals(1, mockSpanner.countRequestsOfType(ExecuteBatchDmlRequest.class));
@@ -705,34 +725,170 @@ public void testRunBatch_TransactionTag() {
@Test
public void testShowSetTags() {
try (Connection connection = createConnection()) {
- connection.execute(Statement.of("SET STATEMENT_TAG='tag1'"));
+ connection.execute(Statement.of(String.format("SET %sSTATEMENT_TAG='tag1'", prefix())));
try (ResultSet rs =
- connection.execute(Statement.of("SHOW VARIABLE STATEMENT_TAG")).getResultSet()) {
+ connection
+ .execute(Statement.of(String.format("SHOW VARIABLE %sSTATEMENT_TAG", prefix())))
+ .getResultSet()) {
assertTrue(rs.next());
- assertEquals("tag1", rs.getString("STATEMENT_TAG"));
+ assertEquals("tag1", rs.getString(String.format("%sSTATEMENT_TAG", prefix())));
assertFalse(rs.next());
}
- connection.execute(Statement.of("SET STATEMENT_TAG=''"));
+ connection.execute(Statement.of(String.format("SET %sSTATEMENT_TAG=''", prefix())));
try (ResultSet rs =
- connection.execute(Statement.of("SHOW VARIABLE STATEMENT_TAG")).getResultSet()) {
+ connection
+ .execute(Statement.of(String.format("SHOW VARIABLE %sSTATEMENT_TAG", prefix())))
+ .getResultSet()) {
assertTrue(rs.next());
- assertEquals("", rs.getString("STATEMENT_TAG"));
+ assertEquals("", rs.getString(String.format("%sSTATEMENT_TAG", prefix())));
assertFalse(rs.next());
}
- connection.execute(Statement.of("SET TRANSACTION_TAG='tag2'"));
+ connection.execute(Statement.of(String.format("SET %sTRANSACTION_TAG='tag2'", prefix())));
try (ResultSet rs =
- connection.execute(Statement.of("SHOW VARIABLE TRANSACTION_TAG")).getResultSet()) {
+ connection
+ .execute(Statement.of(String.format("SHOW VARIABLE %sTRANSACTION_TAG", prefix())))
+ .getResultSet()) {
assertTrue(rs.next());
- assertEquals("tag2", rs.getString("TRANSACTION_TAG"));
+ assertEquals("tag2", rs.getString(String.format("%sTRANSACTION_TAG", prefix())));
assertFalse(rs.next());
}
- connection.execute(Statement.of("SET TRANSACTION_TAG=''"));
+ connection.execute(Statement.of(String.format("SET %sTRANSACTION_TAG=''", prefix())));
try (ResultSet rs =
- connection.execute(Statement.of("SHOW VARIABLE TRANSACTION_TAG")).getResultSet()) {
+ connection
+ .execute(Statement.of(String.format("SHOW VARIABLE %sTRANSACTION_TAG", prefix())))
+ .getResultSet()) {
assertTrue(rs.next());
- assertEquals("", rs.getString("TRANSACTION_TAG"));
+ assertEquals("", rs.getString(String.format("%sTRANSACTION_TAG", prefix())));
assertFalse(rs.next());
}
}
}
+
+ @Test
+ public void testQuery_StatementTagHint() {
+ String sql = SELECT_COUNT_STATEMENT.getSql();
+
+ try (Connection connection = createConnection()) {
+ for (boolean autocommit : new boolean[] {true, false}) {
+ connection.setAutocommit(autocommit);
+ //noinspection EmptyTryBlock
+ try (ResultSet ignore =
+ connection.executeQuery(Statement.of(statementTagHint("tag-1") + sql))) {}
+ assertEquals(1, mockSpanner.countRequestsOfType(ExecuteSqlRequest.class));
+ assertEquals(
+ "tag-1",
+ mockSpanner
+ .getRequestsOfType(ExecuteSqlRequest.class)
+ .get(0)
+ .getRequestOptions()
+ .getRequestTag());
+
+ mockSpanner.clearRequests();
+ }
+ }
+ }
+
+ @Test
+ public void testUpdate_StatementTagHint() {
+ String sql = INSERT_STATEMENT.getSql();
+
+ try (Connection connection = createConnection()) {
+ for (boolean autocommit : new boolean[] {true, false}) {
+ connection.setAutocommit(autocommit);
+ connection.executeUpdate(Statement.of(statementTagHint("tag-2") + sql));
+
+ assertEquals(1, mockSpanner.countRequestsOfType(ExecuteSqlRequest.class));
+ assertEquals(
+ "tag-2",
+ mockSpanner
+ .getRequestsOfType(ExecuteSqlRequest.class)
+ .get(0)
+ .getRequestOptions()
+ .getRequestTag());
+ assertEquals(
+ "",
+ mockSpanner
+ .getRequestsOfType(ExecuteSqlRequest.class)
+ .get(0)
+ .getRequestOptions()
+ .getTransactionTag());
+
+ mockSpanner.clearRequests();
+ }
+ }
+ }
+
+ @Test
+ public void testPartitionedUpdate_StatementTagHint() {
+ String sql = INSERT_STATEMENT.getSql();
+
+ try (Connection connection = createConnection()) {
+ connection.setAutocommit(true);
+ connection.setAutocommitDmlMode(AutocommitDmlMode.PARTITIONED_NON_ATOMIC);
+ connection.executeUpdate(Statement.of(statementTagHint("tag-4") + sql));
+
+ assertEquals(1, mockSpanner.countRequestsOfType(ExecuteSqlRequest.class));
+ assertEquals(
+ "tag-4",
+ mockSpanner
+ .getRequestsOfType(ExecuteSqlRequest.class)
+ .get(0)
+ .getRequestOptions()
+ .getRequestTag());
+ assertEquals(
+ "",
+ mockSpanner
+ .getRequestsOfType(ExecuteSqlRequest.class)
+ .get(0)
+ .getRequestOptions()
+ .getTransactionTag());
+
+ mockSpanner.clearRequests();
+ }
+ }
+
+ @Test
+ public void testBatchUpdate_StatementTagHint() {
+ String sql = INSERT_STATEMENT.getSql();
+
+ try (Connection connection = createConnection()) {
+ for (boolean autocommit : new boolean[] {true, false}) {
+ connection.setAutocommit(autocommit);
+ connection.executeBatchUpdate(
+ Collections.singletonList(Statement.of(statementTagHint("tag-3") + sql)));
+
+ assertEquals(1, mockSpanner.countRequestsOfType(ExecuteBatchDmlRequest.class));
+ assertEquals(
+ "tag-3",
+ mockSpanner
+ .getRequestsOfType(ExecuteBatchDmlRequest.class)
+ .get(0)
+ .getRequestOptions()
+ .getRequestTag());
+ assertEquals(
+ "",
+ mockSpanner
+ .getRequestsOfType(ExecuteBatchDmlRequest.class)
+ .get(0)
+ .getRequestOptions()
+ .getTransactionTag());
+
+ mockSpanner.clearRequests();
+ }
+ }
+ }
+
+ private String statementTagHint(String tag) {
+ switch (dialect) {
+ case POSTGRESQL:
+ return "/*@statement_tag='" + tag + "'*/";
+ case GOOGLE_STANDARD_SQL:
+ default:
+ return "@{statement_tag='" + tag + "'}";
+ }
+ }
+
+ private String prefix() {
+ return dialect == Dialect.POSTGRESQL ? "SPANNER." : "";
+ }
}