feat: improve async sharding #977

Merged · 6 commits · Jun 11, 2024
63 changes: 34 additions & 29 deletions google/cloud/bigtable/data/_async/client.py
@@ -740,43 +740,48 @@ async def read_rows_sharded(
"""
if not sharded_query:
raise ValueError("empty sharded_query")
# reduce operation_timeout between batches
operation_timeout, attempt_timeout = _get_timeouts(
operation_timeout, attempt_timeout, self
)
timeout_generator = _attempt_timeout_generator(
# make sure each rpc stays within overall operation timeout
rpc_timeout_generator = _attempt_timeout_generator(
operation_timeout, operation_timeout
)
# submit shards in batches if the number of shards goes over _CONCURRENCY_LIMIT
batched_queries = [
sharded_query[i : i + _CONCURRENCY_LIMIT]
for i in range(0, len(sharded_query), _CONCURRENCY_LIMIT)
]
# run batches and collect results
results_list = []
error_dict = {}
shard_idx = 0
for batch in batched_queries:
batch_operation_timeout = next(timeout_generator)
routine_list = [
self.read_rows(

# limit the number of concurrent requests using a semaphore
concurrency_sem = asyncio.Semaphore(_CONCURRENCY_LIMIT)

async def read_rows_with_semaphore(query):
async with concurrency_sem:
# calculate new timeout based on time left in overall operation
shard_timeout = next(rpc_timeout_generator)
if shard_timeout <= 0:
raise DeadlineExceeded(
"Operation timeout exceeded before starting query"
Review comment (Contributor):
maybe before starting subquery?

Reply (Contributor Author):
This is the earliest we can check after getting through the semaphore.
There is a similar check at the very beginning as part of _get_timeouts though (which raises an AttributeError)
)
return await self.read_rows(
query,
operation_timeout=batch_operation_timeout,
attempt_timeout=min(attempt_timeout, batch_operation_timeout),
operation_timeout=shard_timeout,
attempt_timeout=min(attempt_timeout, shard_timeout),
retryable_errors=retryable_errors,
)
for query in batch
]
batch_result = await asyncio.gather(*routine_list, return_exceptions=True)
for result in batch_result:
if isinstance(result, Exception):
error_dict[shard_idx] = result
elif isinstance(result, BaseException):
# BaseException not expected; raise immediately
raise result
else:
results_list.extend(result)
shard_idx += 1

routine_list = [read_rows_with_semaphore(query) for query in sharded_query]
batch_result = await asyncio.gather(*routine_list, return_exceptions=True)

# collect results and errors
error_dict = {}
shard_idx = 0
results_list = []
for result in batch_result:
if isinstance(result, Exception):
error_dict[shard_idx] = result
elif isinstance(result, BaseException):
# BaseException not expected; raise immediately
raise result
else:
results_list.extend(result)
shard_idx += 1
if error_dict:
# if any sub-request failed, raise an exception instead of returning results
raise ShardedReadRowsExceptionGroup(
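
Outside the diff context, the pattern this hunk adopts is easy to restate with standard asyncio. The sketch below is an illustration under stated assumptions, not the library's code: shrinking_timeout_generator, fan_out_with_semaphore, and run_shard are hypothetical names, CONCURRENCY_LIMIT stands in for the client's _CONCURRENCY_LIMIT, and the built-in TimeoutError stands in for DeadlineExceeded.

import asyncio
import time

# hypothetical stand-in for the client's _CONCURRENCY_LIMIT
CONCURRENCY_LIMIT = 10


def shrinking_timeout_generator(first_timeout: float, overall_timeout: float):
    """Yield per-shard timeouts that shrink as the overall deadline approaches.

    Loosely modeled on the _attempt_timeout_generator helper referenced above.
    """
    deadline = time.monotonic() + overall_timeout
    while True:
        yield min(first_timeout, deadline - time.monotonic())


async def fan_out_with_semaphore(shards, run_shard, operation_timeout: float):
    """Run one coroutine per shard, at most CONCURRENCY_LIMIT at a time."""
    timeouts = shrinking_timeout_generator(operation_timeout, operation_timeout)
    semaphore = asyncio.Semaphore(CONCURRENCY_LIMIT)

    async def run_one(shard):
        async with semaphore:
            # budget for this shard = time left in the overall operation
            shard_timeout = next(timeouts)
            if shard_timeout <= 0:
                raise TimeoutError("operation timeout exceeded before starting shard")
            return await asyncio.wait_for(run_shard(shard), timeout=shard_timeout)

    # return_exceptions=True lets successful shards finish even if others fail,
    # mirroring the error_dict / ShardedReadRowsExceptionGroup handling above
    return await asyncio.gather(*(run_one(s) for s in shards), return_exceptions=True)

Unlike the removed batching approach, no shard waits for an entire batch to drain: a queued shard starts as soon as any of the first CONCURRENCY_LIMIT shards releases the semaphore.
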
155 changes: 107 additions & 48 deletions tests/unit/data/_async/test_client.py
@@ -1927,62 +1927,121 @@ async def mock_call(*args, **kwargs):
assert call_time < 0.2

@pytest.mark.asyncio
async def test_read_rows_sharded_batching(self):
async def test_read_rows_sharded_concurrency_limit(self):
"""
Large queries should be processed in batches to limit concurrency
operation timeout should change between batches
Only 10 queries should be processed concurrently. Others should be queued

Should start a new query as soon as previous finishes
"""
from google.cloud.bigtable.data._async.client import TableAsync
from google.cloud.bigtable.data._async.client import _CONCURRENCY_LIMIT

assert _CONCURRENCY_LIMIT == 10 # change this test if this changes
num_queries = 15

n_queries = 90
expected_num_batches = n_queries // _CONCURRENCY_LIMIT
query_list = [ReadRowsQuery() for _ in range(n_queries)]

table_mock = AsyncMock()
start_operation_timeout = 10
start_attempt_timeout = 3
table_mock.default_read_rows_operation_timeout = start_operation_timeout
table_mock.default_read_rows_attempt_timeout = start_attempt_timeout
# clock ticks one second on each check
with mock.patch("time.monotonic", side_effect=range(0, 100000)):
with mock.patch("asyncio.gather", AsyncMock()) as gather_mock:
await TableAsync.read_rows_sharded(table_mock, query_list)
# should have individual calls for each query
assert table_mock.read_rows.call_count == n_queries
# should have single gather call for each batch
assert gather_mock.call_count == expected_num_batches
# ensure that timeouts decrease over time
kwargs = [
table_mock.read_rows.call_args_list[idx][1]
for idx in range(n_queries)
]
for batch_idx in range(expected_num_batches):
batch_kwargs = kwargs[
batch_idx
* _CONCURRENCY_LIMIT : (batch_idx + 1)
* _CONCURRENCY_LIMIT
# each of the first 10 queries take longer than the last
# later rpcs will have to wait on first 10
increment_time = 0.05
max_time = increment_time * (_CONCURRENCY_LIMIT - 1)
rpc_times = [min(i * increment_time, max_time) for i in range(num_queries)]

async def mock_call(*args, **kwargs):
next_sleep = rpc_times.pop(0)
await asyncio.sleep(next_sleep)
return [mock.Mock()]

starting_timeout = 10

async with _make_client() as client:
async with client.get_table("instance", "table") as table:
with mock.patch.object(table, "read_rows") as read_rows:
read_rows.side_effect = mock_call
queries = [ReadRowsQuery() for _ in range(num_queries)]
await table.read_rows_sharded(
queries, operation_timeout=starting_timeout
)
assert read_rows.call_count == num_queries
# check operation timeouts to see how far into the operation each rpc started
rpc_start_list = [
starting_timeout - kwargs["operation_timeout"]
for _, kwargs in read_rows.call_args_list
]
for req_kwargs in batch_kwargs:
# each batch should have the same operation_timeout, and it should decrease in each batch
expected_operation_timeout = start_operation_timeout - (
batch_idx + 1
)
assert (
req_kwargs["operation_timeout"]
== expected_operation_timeout
)
# each attempt_timeout should start with default value, but decrease when operation_timeout reaches it
expected_attempt_timeout = min(
start_attempt_timeout, expected_operation_timeout
eps = 0.01
# first 10 should start immediately
assert all(
rpc_start_list[i] < eps for i in range(_CONCURRENCY_LIMIT)
)
# next rpcs should start as first ones finish
for i in range(num_queries - _CONCURRENCY_LIMIT):
idx = i + _CONCURRENCY_LIMIT
assert rpc_start_list[idx] - (i * increment_time) < eps

@pytest.mark.asyncio
async def test_read_rows_sharded_expirary(self):
"""
If the operation times out before all shards complete, should raise
a ShardedReadRowsExceptionGroup
"""
from google.cloud.bigtable.data._async.client import _CONCURRENCY_LIMIT
from google.cloud.bigtable.data.exceptions import ShardedReadRowsExceptionGroup
from google.api_core.exceptions import DeadlineExceeded

operation_timeout = 0.1

# let the first batch complete, but the next batch times out
num_queries = 15
sleeps = [0] * _CONCURRENCY_LIMIT + [DeadlineExceeded("times up")] * (
num_queries - _CONCURRENCY_LIMIT
)

async def mock_call(*args, **kwargs):
next_item = sleeps.pop(0)
if isinstance(next_item, Exception):
raise next_item
else:
await asyncio.sleep(next_item)
return [mock.Mock()]

async with _make_client() as client:
async with client.get_table("instance", "table") as table:
with mock.patch.object(table, "read_rows") as read_rows:
read_rows.side_effect = mock_call
queries = [ReadRowsQuery() for _ in range(num_queries)]
with pytest.raises(ShardedReadRowsExceptionGroup) as exc:
await table.read_rows_sharded(
queries, operation_timeout=operation_timeout
)
assert req_kwargs["attempt_timeout"] == expected_attempt_timeout
# await all created coroutines to avoid warnings
for i in range(len(gather_mock.call_args_list)):
for j in range(len(gather_mock.call_args_list[i][0])):
await gather_mock.call_args_list[i][0][j]
assert isinstance(exc.value, ShardedReadRowsExceptionGroup)
assert len(exc.value.exceptions) == num_queries - _CONCURRENCY_LIMIT
# should keep successful queries
assert len(exc.value.successful_rows) == _CONCURRENCY_LIMIT

@pytest.mark.asyncio
async def test_read_rows_sharded_negative_batch_timeout(self):
"""
try to run with batch that starts after operation timeout

They should raise DeadlineExceeded errors
"""
from google.cloud.bigtable.data.exceptions import ShardedReadRowsExceptionGroup
from google.api_core.exceptions import DeadlineExceeded

async def mock_call(*args, **kwargs):
await asyncio.sleep(0.05)
return [mock.Mock()]

async with _make_client() as client:
async with client.get_table("instance", "table") as table:
with mock.patch.object(table, "read_rows") as read_rows:
read_rows.side_effect = mock_call
queries = [ReadRowsQuery() for _ in range(15)]
with pytest.raises(ShardedReadRowsExceptionGroup) as exc:
await table.read_rows_sharded(queries, operation_timeout=0.01)
assert isinstance(exc.value, ShardedReadRowsExceptionGroup)
assert len(exc.value.exceptions) == 5
assert all(
isinstance(e.__cause__, DeadlineExceeded)
for e in exc.value.exceptions
)


class TestSampleRowKeys:
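
For context on how the changed method is exercised, a minimal caller-side sketch. It assumes the API surface visible in the diff (read_rows_sharded, operation_timeout, ShardedReadRowsExceptionGroup with exceptions and successful_rows) plus the sample_row_keys/shard workflow; the client class name, import paths, and the instance and table IDs are assumptions and placeholders, not taken from this PR.

import asyncio

from google.cloud.bigtable.data import BigtableDataClientAsync, ReadRowsQuery
from google.cloud.bigtable.data.exceptions import ShardedReadRowsExceptionGroup


async def read_table_sharded():
    async with BigtableDataClientAsync() as client:
        async with client.get_table("my-instance", "my-table") as table:
            # split a full-table scan into one query per tablet boundary
            shard_keys = await table.sample_row_keys()
            sharded_query = ReadRowsQuery().shard(shard_keys)
            try:
                return await table.read_rows_sharded(
                    sharded_query, operation_timeout=60
                )
            except ShardedReadRowsExceptionGroup as e:
                # partial results survive on the exception group
                print(f"{len(e.exceptions)} shard(s) failed")
                return e.successful_rows


if __name__ == "__main__":
    asyncio.run(read_table_sharded())

With the semaphore-based change in this PR, a caller passing many shards still sees at most ten concurrent RPCs, and any shard that cannot start before the operation timeout expires surfaces as a DeadlineExceeded entry in the exception group rather than silently running past the deadline.
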