Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 60 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -103,3 +103,63 @@ users may find useful:
Consider multi-session for potential cost savings, but be mindful of
performance impacts from shared resources. You might need to adjust
cluster size if slowdowns occur, which could affect overall cost.

### Storing results to cloud storage

Instead of receiving query results inline over the WebSocket connection,
you can have the server write them to cloud storage (S3) using the
`Store` class. This is useful for large result sets or when you need a
downloadable file.

```python
from wherobots.db import connect, Store, StorageFormat
from wherobots.db.region import Region
from wherobots.db.runtime import Runtime

with connect(
api_key='...',
runtime=Runtime.TINY,
region=Region.AWS_US_WEST_2) as conn:
curr = conn.cursor()

# Store results as a single GeoJSON file with a presigned download URL
store = Store.for_download(format=StorageFormat.GEOJSON)
curr.execute("SELECT * FROM my_table", store=store)
results = curr.fetchall()
```

#### Store options

You can pass format-specific Spark write options through the `options`
parameter. These correspond to the options available in Spark's
`DataFrameWriter`. They are applied after the server's default options, so
any option you set overrides the server's default for that option.

```python
# CSV without a header row, using "|" as the field delimiter
store = Store.for_download(
format=StorageFormat.CSV,
options={"header": "false", "delimiter": "|"},
)

# GeoJSON preserving null fields
store = Store.for_download(
format=StorageFormat.GEOJSON,
options={"ignoreNullFields": "false"},
)
```

You can also set a default `Store` at connection time. It is used for every
query executed through cursors created from that connection, unless you
pass a different `store` to an individual `execute()` call:

```python
with connect(
api_key='...',
runtime=Runtime.TINY,
region=Region.AWS_US_WEST_2,
store=Store.for_download(format=StorageFormat.PARQUET)) as conn:
curr = conn.cursor()
# All queries through this cursor will use the connection-level store
curr.execute("SELECT * FROM my_table")
```
170 changes: 170 additions & 0 deletions tests/test_result_store.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,170 @@
"""Tests for result_store module: Store dataclass and StorageFormat enum."""

import json
import pytest

from wherobots.db.result_store import Store, StorageFormat, DEFAULT_STORAGE_FORMAT


class TestStorageFormat:
    """Checks the string values of ``StorageFormat`` and the module default."""

    def test_values(self):
        # Each enum member paired with the string value it must carry.
        expected_pairs = [
            (StorageFormat.PARQUET, "parquet"),
            (StorageFormat.CSV, "csv"),
            (StorageFormat.GEOJSON, "geojson"),
        ]
        for member, expected_value in expected_pairs:
            assert member.value == expected_value

    def test_default_format(self):
        assert DEFAULT_STORAGE_FORMAT == StorageFormat.PARQUET


class TestStore:
    """Exercises ``Store`` construction: defaults, options handling, validation."""

    def test_default_construction(self):
        default_store = Store()
        assert default_store.format == StorageFormat.PARQUET
        assert default_store.single is False
        assert default_store.generate_presigned_url is False
        assert default_store.options is None

    def test_with_format(self):
        csv_store = Store(format=StorageFormat.CSV)
        assert csv_store.format == StorageFormat.CSV
        assert csv_store.options is None

    def test_with_options(self):
        geojson_options = {"ignoreNullFields": "false"}
        geojson_store = Store(format=StorageFormat.GEOJSON, options=geojson_options)
        assert geojson_store.options == {"ignoreNullFields": "false"}

    def test_with_multiple_options(self):
        csv_options = {"header": "false", "delimiter": "|", "quote": '"'}
        csv_store = Store(options=csv_options, format=StorageFormat.CSV)
        assert csv_store.options == csv_options

    def test_empty_options_normalized_to_none(self):
        # An empty mapping is expected to be stored as None.
        assert Store(options={}).options is None

    def test_none_options(self):
        assert Store(options=None).options is None

    def test_options_defensively_copied(self):
        caller_options = {"key": "value"}
        store = Store(options=caller_options)
        # Changing the caller's dict afterwards must not leak into the store.
        caller_options.update(key="changed")
        assert store.options == {"key": "value"}

    def test_frozen_dataclass(self):
        frozen_store = Store()
        # Attribute assignment on the instance must be rejected.
        with pytest.raises(AttributeError):
            frozen_store.format = StorageFormat.CSV

    def test_presigned_url_requires_single(self):
        # The error message is expected to mention "single=True".
        with pytest.raises(ValueError, match="single=True"):
            Store(single=False, generate_presigned_url=True)

    def test_presigned_url_with_single(self):
        presigned_store = Store(generate_presigned_url=True, single=True)
        assert presigned_store.single is True
        assert presigned_store.generate_presigned_url is True


class TestStoreForDownload:
    """Checks the ``Store.for_download`` alternate constructor."""

    def test_default(self):
        download_store = Store.for_download()
        assert download_store.format == StorageFormat.PARQUET
        assert download_store.single is True
        assert download_store.generate_presigned_url is True
        assert download_store.options is None

    def test_with_format(self):
        csv_download = Store.for_download(format=StorageFormat.CSV)
        assert csv_download.format == StorageFormat.CSV
        assert csv_download.single is True
        assert csv_download.generate_presigned_url is True

    def test_with_options(self):
        geojson_options = {"ignoreNullFields": "false"}
        geojson_download = Store.for_download(
            options=geojson_options,
            format=StorageFormat.GEOJSON,
        )
        assert geojson_download.format == StorageFormat.GEOJSON
        assert geojson_download.options == {"ignoreNullFields": "false"}


class TestStoreToDict:
    """Checks ``Store.to_dict`` output and its use inside a JSON request."""

    def test_without_options(self):
        payload = Store(single=True, format=StorageFormat.PARQUET).to_dict()
        # With no options set, the key must be absent rather than null/empty.
        assert "options" not in payload
        assert payload == {
            "format": "parquet",
            "single": True,
            "generate_presigned_url": False,
        }

    def test_with_options(self):
        full_store = Store(
            format=StorageFormat.GEOJSON,
            single=True,
            generate_presigned_url=True,
            options={"ignoreNullFields": "false"},
        )
        expected_payload = {
            "format": "geojson",
            "single": True,
            "generate_presigned_url": True,
            "options": {"ignoreNullFields": "false"},
        }
        assert full_store.to_dict() == expected_payload

    def test_serializable_to_json(self):
        csv_download = Store.for_download(
            format=StorageFormat.CSV,
            options={"header": "false"},
        )
        # Round-trip through JSON text and inspect what comes back.
        round_tripped = json.loads(json.dumps(csv_download.to_dict()))
        assert round_tripped["format"] == "csv"
        assert round_tripped["single"] is True
        assert round_tripped["generate_presigned_url"] is True
        assert round_tripped["options"] == {"header": "false"}

    def test_to_dict_returns_copy(self):
        """Editing the dict returned by ``to_dict`` must leave the Store intact."""
        store = Store(options={"key": "value"})
        payload = store.to_dict()
        payload["options"]["key"] = "changed"
        assert store.options == {"key": "value"}

    def test_full_execute_sql_request_shape(self):
        """The store dict can be embedded in an execute_sql request and JSON-encoded."""
        geojson_download = Store.for_download(
            format=StorageFormat.GEOJSON,
            options={"ignoreNullFields": "false"},
        )
        request = {
            "kind": "execute_sql",
            "execution_id": "test-id",
            "statement": "SELECT 1",
            "store": geojson_download.to_dict(),
        }
        stored = json.loads(json.dumps(request))["store"]
        assert stored["format"] == "geojson"
        assert stored["single"] is True
        assert stored["generate_presigned_url"] is True
        assert stored["options"] == {"ignoreNullFields": "false"}

    def test_request_without_store(self):
        """A request built without a store carries no ``store`` key after a JSON round trip."""
        request = {
            "kind": "execute_sql",
            "execution_id": "test-id",
            "statement": "SELECT 1",
        }
        round_tripped = json.loads(json.dumps(request))
        assert "store" not in round_tripped
3 changes: 3 additions & 0 deletions wherobots/db/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
NotSupportedError,
)
from .region import Region
from .result_store import Store, StorageFormat
from .runtime import Runtime

__all__ = [
Expand All @@ -27,4 +28,6 @@
"NotSupportedError",
"Region",
"Runtime",
"Store",
"StorageFormat",
]
15 changes: 13 additions & 2 deletions wherobots/db/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
)
from wherobots.db.cursor import Cursor
from wherobots.db.errors import NotSupportedError, OperationalError
from wherobots.db.result_store import Store


@dataclass
Expand Down Expand Up @@ -56,12 +57,14 @@ def __init__(
results_format: Union[ResultsFormat, None] = None,
data_compression: Union[DataCompression, None] = None,
geometry_representation: Union[GeometryRepresentation, None] = None,
store: Union[Store, None] = None,
):
self.__ws = ws
self.__read_timeout = read_timeout
self.__results_format = results_format
self.__data_compression = data_compression
self.__geometry_representation = geometry_representation
self.__store = store

self.__queries: dict[str, Query] = {}
self.__thread = threading.Thread(
Expand All @@ -85,7 +88,7 @@ def rollback(self) -> None:
raise NotSupportedError

def cursor(self) -> Cursor:
return Cursor(self.__execute_sql, self.__cancel_query)
return Cursor(self.__execute_sql, self.__cancel_query, self.__store)

def __main_loop(self) -> None:
"""Main background loop listening for messages from the SQL session."""
Expand Down Expand Up @@ -200,7 +203,12 @@ def __recv(self) -> Dict[str, Any]:
raise ValueError("Unexpected frame type received")
return message

def __execute_sql(self, sql: str, handler: Callable[[Any], None]) -> str:
def __execute_sql(
self,
sql: str,
handler: Callable[[Any], None],
store: Union[Store, None] = None,
) -> str:
"""Triggers the execution of the given SQL query."""
execution_id = str(uuid.uuid4())
request = {
Expand All @@ -209,6 +217,9 @@ def __execute_sql(self, sql: str, handler: Callable[[Any], None]) -> str:
"statement": sql,
}

if store is not None:
request["store"] = store.to_dict()

self.__queries[execution_id] = Query(
sql=sql,
execution_id=execution_id,
Expand Down
23 changes: 19 additions & 4 deletions wherobots/db/cursor.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import queue
from typing import Any, Optional, List, Tuple, Dict
from typing import Any, Optional, List, Tuple, Dict, Union

from .errors import DatabaseError, ProgrammingError
from .result_store import Store

_TYPE_MAP = {
"object": "STRING",
Expand All @@ -15,9 +16,15 @@


class Cursor:
def __init__(self, exec_fn, cancel_fn) -> None:
def __init__(
self,
exec_fn,
cancel_fn,
default_store: Union[Store, None] = None,
) -> None:
self.__exec_fn = exec_fn
self.__cancel_fn = cancel_fn
self.__default_store = default_store

self.__queue: queue.Queue = queue.Queue()
self.__results: Optional[list[Any]] = None
Expand Down Expand Up @@ -71,7 +78,12 @@ def __get_results(self) -> Optional[List[Tuple[Any, ...]]]:

return self.__results

def execute(self, operation: str, parameters: Dict[str, Any] = None) -> None:
def execute(
self,
operation: str,
parameters: Dict[str, Any] = None,
store: Union[Store, None] = None,
) -> None:
if self.__current_execution_id:
self.__cancel_fn(self.__current_execution_id)

Expand All @@ -83,7 +95,10 @@ def execute(self, operation: str, parameters: Dict[str, Any] = None) -> None:
sql = (
operation.replace("{", "{{").replace("}", "}}").format(**(parameters or {}))
)
self.__current_execution_id = self.__exec_fn(sql, self.__on_execution_result)
effective_store = store if store is not None else self.__default_store
self.__current_execution_id = self.__exec_fn(
sql, self.__on_execution_result, effective_store
)

def executemany(
self, operation: str, seq_of_parameters: List[Dict[str, Any]]
Expand Down
5 changes: 5 additions & 0 deletions wherobots/db/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
OperationalError,
)
from .region import Region
from .result_store import Store
from .runtime import Runtime

apilevel = "2.0"
Expand Down Expand Up @@ -69,6 +70,7 @@ def connect(
results_format: Union[ResultsFormat, None] = None,
data_compression: Union[DataCompression, None] = None,
geometry_representation: Union[GeometryRepresentation, None] = None,
store: Union[Store, None] = None,
) -> Connection:
if not token and not api_key:
raise ValueError("At least one of `token` or `api_key` is required")
Expand Down Expand Up @@ -151,6 +153,7 @@ def get_session_uri() -> str:
results_format=results_format,
data_compression=data_compression,
geometry_representation=geometry_representation,
store=store,
)


Expand All @@ -171,6 +174,7 @@ def connect_direct(
results_format: Union[ResultsFormat, None] = None,
data_compression: Union[DataCompression, None] = None,
geometry_representation: Union[GeometryRepresentation, None] = None,
store: Union[Store, None] = None,
) -> Connection:
uri_with_protocol = f"{uri}/{protocol}"

Expand All @@ -193,4 +197,5 @@ def connect_direct(
results_format=results_format,
data_compression=data_compression,
geometry_representation=geometry_representation,
store=store,
)
Loading