Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions nhs_aws_helpers/dynamodb_model_store/base_model_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -561,6 +561,25 @@ async def paginate_models(
last_evaluated_key=last_evaluated_key,
)

async def query_models_from_index(
self,
model_type: Type[TBaseModel_co],
index_name: str,
max_concurrency=10,
**kwargs,
) -> PagedItems[TBaseModel_co]:
items, last_evaluated_key = await self.query_items(IndexName=index_name, **kwargs)
model_keys = [model_type.model_key_from_item(item) for item in items]
models = [
cast(TBaseModel_co, record)
for record in await self.batch_get_model(model_keys, max_concurrency=max_concurrency)
if record
]
return PagedItems(
items=models,
last_evaluated_key=last_evaluated_key,
)

async def paginate_models_from_index(
self,
paginator_type: Literal["query", "scan", "list_backups", "list_tables"],
Expand Down
31 changes: 30 additions & 1 deletion tests/base_model_store_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,18 @@
from dataclasses import dataclass, field
from datetime import date, datetime
from enum import Enum
from typing import Any, Dict, Generator, List, Mapping, Optional, Type, TypedDict, Union, cast
from typing import (
Any,
Dict,
Generator,
List,
Mapping,
Optional,
Type,
TypedDict,
Union,
cast,
)
from uuid import uuid4

import petname # type: ignore[import]
Expand Down Expand Up @@ -744,6 +755,24 @@ async def test_paginate_models_from_index(store: MyModelStore):
assert all(isinstance(model, AnotherModel) for model in page.items)


async def test_query_models_from_index(store: MyModelStore):
async with store.batch_writer() as writer:
for i in range(100):
await writer.put_item(AnotherModel(id="1", sk_field=f"SK:{i}"))


response = await store.query_models_from_index(
model_type=AnotherModel,
index_name="gsi_model_type",
KeyConditionExpression=Key("model_type").eq(AnotherModel.__name__),
Limit=40,
)

assert len(response.items) == 40
assert all(isinstance(model, AnotherModel) for model in response.items)
assert response.last_evaluated_key == {"model_type": "AnotherModel", "my_pk": "BB#1", "my_sk": "SK:47"}


async def test_paged_items():
has_size = PagedItems(items=[1, 2, 3], last_evaluated_key={"my_pk": 123})
assert has_size
Expand Down
Loading