Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: improve async sharding #977

Merged
merged 6 commits into from
Jun 11, 2024
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
fixed test
  • Loading branch information
daniel-sanche committed Jun 5, 2024
commit 48f8bc232f6d5e21ebe5034844afafd9f5e6a958
89 changes: 40 additions & 49 deletions tests/unit/data/_async/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -1927,62 +1927,53 @@ async def mock_call(*args, **kwargs):
assert call_time < 0.2

@pytest.mark.asyncio
async def test_read_rows_sharded_batching(self):
async def test_read_rows_sharded_concurrency_limit(self):
"""
Large queries should be processed in batches to limit concurrency
operation timeout should change between batches
Only 10 queries should be processed concurrently. Others should be queued

Should start a new query as soon as previous finishes
"""
from google.cloud.bigtable.data._async.client import TableAsync
from google.cloud.bigtable.data._async.client import _CONCURRENCY_LIMIT

assert _CONCURRENCY_LIMIT == 10 # change this test if this changes
num_queries = 15

n_queries = 90
expected_num_batches = n_queries // _CONCURRENCY_LIMIT
query_list = [ReadRowsQuery() for _ in range(n_queries)]

table_mock = AsyncMock()
start_operation_timeout = 10
start_attempt_timeout = 3
table_mock.default_read_rows_operation_timeout = start_operation_timeout
table_mock.default_read_rows_attempt_timeout = start_attempt_timeout
# clock ticks one second on each check
with mock.patch("time.monotonic", side_effect=range(0, 100000)):
with mock.patch("asyncio.gather", AsyncMock()) as gather_mock:
await TableAsync.read_rows_sharded(table_mock, query_list)
# should have individual calls for each query
assert table_mock.read_rows.call_count == n_queries
# should have single gather call for each batch
assert gather_mock.call_count == expected_num_batches
# ensure that timeouts decrease over time
kwargs = [
table_mock.read_rows.call_args_list[idx][1]
for idx in range(n_queries)
]
for batch_idx in range(expected_num_batches):
batch_kwargs = kwargs[
batch_idx
* _CONCURRENCY_LIMIT : (batch_idx + 1)
* _CONCURRENCY_LIMIT
# each of the first 10 queries takes longer than the one before it
# later rpcs will have to wait on the first 10
increment_time = 0.05
max_time = increment_time * (_CONCURRENCY_LIMIT - 1)
rpc_times = [min(i * increment_time, max_time) for i in range(num_queries)]

async def mock_call(*args, **kwargs):
next_sleep = rpc_times.pop(0)
await asyncio.sleep(next_sleep)
return [mock.Mock()]

starting_timeout = 10

async with _make_client() as client:
async with client.get_table("instance", "table") as table:
with mock.patch.object(table, "read_rows") as read_rows:
read_rows.side_effect = mock_call
queries = [ReadRowsQuery() for _ in range(num_queries)]
await table.read_rows_sharded(
queries, operation_timeout=starting_timeout
)
assert read_rows.call_count == num_queries
# check operation timeouts to see how far into the operation each rpc started
rpc_start_list = [
starting_timeout - kwargs["operation_timeout"]
for _, kwargs in read_rows.call_args_list
]
for req_kwargs in batch_kwargs:
# each batch should have the same operation_timeout, and it should decrease in each batch
expected_operation_timeout = start_operation_timeout - (
batch_idx + 1
)
assert (
req_kwargs["operation_timeout"]
== expected_operation_timeout
)
# each attempt_timeout should start with default value, but decrease when operation_timeout reaches it
expected_attempt_timeout = min(
start_attempt_timeout, expected_operation_timeout
)
assert req_kwargs["attempt_timeout"] == expected_attempt_timeout
# await all created coroutines to avoid warnings
for i in range(len(gather_mock.call_args_list)):
for j in range(len(gather_mock.call_args_list[i][0])):
await gather_mock.call_args_list[i][0][j]
eps = 0.01
# first 10 should start immediately
assert all(
rpc_start_list[i] < eps for i in range(_CONCURRENCY_LIMIT)
)
# next rpcs should start as first ones finish
for i in range(num_queries - _CONCURRENCY_LIMIT):
idx = i + _CONCURRENCY_LIMIT
assert rpc_start_list[idx] - (i * increment_time) < eps


class TestSampleRowKeys:
Expand Down
Loading