diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 41886fd..e5a8d64 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -34,7 +34,7 @@ repos: - repo: https://github.com/astral-sh/ruff-pre-commit # Ruff version. - rev: v0.14.6 + rev: v0.14.8 hooks: # Run the linter. - id: ruff-check diff --git a/.vscode/extensions.json b/.vscode/extensions.json index f3936b1..5d21c81 100644 --- a/.vscode/extensions.json +++ b/.vscode/extensions.json @@ -1,5 +1,5 @@ { "recommendations": [ - "golang.go" + "charliermarsh.ruff" ] } diff --git a/conftest.py b/conftest.py index b75e5b8..8683e7b 100644 --- a/conftest.py +++ b/conftest.py @@ -1,5 +1,4 @@ -""" -Pytest configuration and fixtures for the OpenTDF Python SDK tests. +"""Pytest configuration and fixtures for the OpenTDF Python SDK tests. This module contains pytest hooks and fixtures that will be automatically loaded by pytest when running tests. @@ -14,13 +13,13 @@ @pytest.fixture(scope="session") def project_root(request) -> Path: + """Get project root directory.""" return request.config.rootpath # Project root @pytest.hookimpl(tryfirst=True, hookwrapper=True) def pytest_runtest_makereport(item, call): - """ - Hook that runs after each test phase (setup, call, teardown). + """Collect server logs when test fails after each test phase. This hook automatically collects server logs when a test fails. """ @@ -53,8 +52,7 @@ def pytest_runtest_makereport(item, call): @pytest.fixture def collect_server_logs(): - """ - Fixture that provides a function to manually collect server logs. + """Fixture that provides a function to manually collect server logs. Usage: def test_something(collect_server_logs): diff --git a/otdf-python-proto/scripts/generate_connect_proto.py b/otdf-python-proto/scripts/generate_connect_proto.py index 111cd81..9ac1e39 100644 --- a/otdf-python-proto/scripts/generate_connect_proto.py +++ b/otdf-python-proto/scripts/generate_connect_proto.py @@ -1,6 +1,5 @@ #!/usr/bin/env python3 -""" -Enhanced script to generate Python Connect RPC clients from .proto definitions. +"""Enhanced script to generate Python Connect RPC clients from .proto definitions. This script: 1. Downloads the latest proto files from OpenTDF platform @@ -200,8 +199,7 @@ def create_init_files(generated_dir: Path) -> None: def _fix_ignore_if_default_value(proto_files_dir): - """ - TODO: Fix buf validation: Updated the proto files to use the correct enum value: + """TODO: Fix buf validation: Updated the proto files to use the correct enum value: Changed IGNORE_IF_DEFAULT_VALUE → IGNORE_IF_ZERO_VALUE in: attributes.proto diff --git a/pyproject.toml b/pyproject.toml index 7cd5081..f4a4aa3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -72,6 +72,7 @@ norecursedirs = ["otdf-python-proto"] [tool.ruff] line-length = 88 +target-version = "py310" # See https://docs.astral.sh/ruff/rules/ # for rule information. @@ -80,28 +81,33 @@ lint.ignore = [ "E501", ] lint.select = [ - # pycodestyle checks. - "E", - "W", - # pyflakes checks. - "F", - # flake8-bugbear checks. - "B", - # flake8-comprehensions checks. 
- "C4", - # McCabe complexity - "C90", - # isort - "I", - # Performance-related rules - "PERF", # Ruff's performance rules - "PTH", # pathlib (path handling) - # Additional useful rules - "UP", # pyupgrade (modern Python features) - "SIM", # flake8-simplify (simplifications) - "RUF", # Ruff-specific rules - "FURB", # refurb (FURB) - "PT018", # flake8-pytest-style (pytest style) + "B", # flake8-bugbear + "C4", # flake8-comprehensions + "C90", # McCabe complexity + "D", # pydocstyle + "DOC", # pydoclint + "E", # pycodestyle errors + "F", # pyflakes + "FURB", # refurb + "I", # isort + "PERF", # performance + "PT018", # pytest style + "PTH", # pathlib + "Q", # flake8-quotes + "RUF", # ruff-specific + "SIM", # flake8-simplify + "UP", # pyupgrade + "W", # pycodestyle warnings ] # Ignore generated files extend-exclude = ["otdf-python-proto/src/"] + +[tool.ruff.lint.per-file-ignores] +"tests/**" = ["D100", "D101", "D102", "D103", "D107", "D400", "D401", "D415"] +"otdf-python-proto/**" = ["D"] # Ignore all D (docstring) rules for generated proto files + +# TODO: Remaining work - 4 buckets to fix (140 errors remaining): +# Bucket #1: D102 (missing method docstrings) - 98 errors +# Bucket #2: D105 (missing magic method docstrings) - 23 errors +# Bucket #3: D205 (blank line formatting), D103, D417, D104 - 19 errors +"src/**" = ["D102", "D105", "D205"] diff --git a/src/otdf_python/__init__.py b/src/otdf_python/__init__.py index f8dcd49..b9686eb 100644 --- a/src/otdf_python/__init__.py +++ b/src/otdf_python/__init__.py @@ -1,5 +1,4 @@ -""" -OpenTDF Python SDK +"""OpenTDF Python SDK. A Python implementation of the OpenTDF SDK for working with Trusted Data Format (TDF) files. Provides both programmatic APIs and command-line interface for encryption and decryption. diff --git a/src/otdf_python/__main__.py b/src/otdf_python/__main__.py index e3e8c97..798d10b 100644 --- a/src/otdf_python/__main__.py +++ b/src/otdf_python/__main__.py @@ -1,6 +1,5 @@ #!/usr/bin/env python3 -""" -Main entry point for running otdf_python as a module. +"""Main entry point for running otdf_python as a module. This allows the package to be run with `python -m otdf_python` and properly handles the CLI interface without import conflicts. diff --git a/src/otdf_python/address_normalizer.py b/src/otdf_python/address_normalizer.py index e3acf15..14f15d2 100644 --- a/src/otdf_python/address_normalizer.py +++ b/src/otdf_python/address_normalizer.py @@ -1,6 +1,4 @@ -""" -Address normalization utilities for OpenTDF. -""" +"""Address normalization utilities for OpenTDF.""" import logging import re @@ -12,8 +10,7 @@ def normalize_address(url_string: str, use_plaintext: bool) -> str: - """ - Normalize a URL address to ensure it has the correct scheme and port. + """Normalize a URL address to ensure it has the correct scheme and port. 
Args: url_string: The URL string to normalize @@ -24,6 +21,7 @@ def normalize_address(url_string: str, use_plaintext: bool) -> str: Raises: SDKException: If there's an error parsing or creating the URL + """ scheme = "http" if use_plaintext else "https" diff --git a/src/otdf_python/aesgcm.py b/src/otdf_python/aesgcm.py index ced6427..6385aee 100644 --- a/src/otdf_python/aesgcm.py +++ b/src/otdf_python/aesgcm.py @@ -1,13 +1,18 @@ +"""AES-GCM encryption and decryption functionality.""" + import os from cryptography.hazmat.primitives.ciphers.aead import AESGCM class AesGcm: + """AES-GCM encryption and decryption operations.""" + GCM_NONCE_LENGTH = 12 GCM_TAG_LENGTH = 16 def __init__(self, key: bytes): + """Initialize AES-GCM cipher with key.""" if not key or len(key) not in (16, 24, 32): raise ValueError("Invalid key size for GCM encryption") self.key = key @@ -17,7 +22,10 @@ def get_key(self) -> bytes: return self.key class Encrypted: + """Encrypted data with initialization vector and ciphertext.""" + def __init__(self, iv: bytes, ciphertext: bytes): + """Initialize encrypted data.""" self.iv = iv self.ciphertext = ciphertext diff --git a/src/otdf_python/assertion_config.py b/src/otdf_python/assertion_config.py index 4c96fb2..2eae552 100644 --- a/src/otdf_python/assertion_config.py +++ b/src/otdf_python/assertion_config.py @@ -1,8 +1,12 @@ +"""Assertion configuration for TDF.""" + from enum import Enum, auto from typing import Any class Type(Enum): + """Assertion type enumeration.""" + HANDLING_ASSERTION = "handling" BASE_ASSERTION = "base" @@ -11,6 +15,8 @@ def __str__(self): class Scope(Enum): + """Assertion scope enumeration.""" + TRUSTED_DATA_OBJ = "tdo" PAYLOAD = "payload" @@ -19,12 +25,16 @@ def __str__(self): class AssertionKeyAlg(Enum): + """Assertion key algorithm enumeration.""" + RS256 = auto() HS256 = auto() NOT_DEFINED = auto() class AppliesToState(Enum): + """Assertion applies-to state enumeration.""" + ENCRYPTED = "encrypted" UNENCRYPTED = "unencrypted" @@ -33,6 +43,8 @@ def __str__(self): class BindingMethod(Enum): + """Assertion binding method enumeration.""" + JWS = "jws" def __str__(self): @@ -40,7 +52,10 @@ def __str__(self): class AssertionKey: + """Assertion signing key configuration.""" + def __init__(self, alg: AssertionKeyAlg, key: Any): + """Initialize assertion key.""" self.alg = alg self.key = key @@ -49,7 +64,10 @@ def is_defined(self): class Statement: + """Assertion statement with format, schema, and value.""" + def __init__(self, format: str, schema: str, value: str): + """Initialize assertion statement.""" self.format = format self.schema = schema self.value = value @@ -67,6 +85,8 @@ def __hash__(self): class AssertionConfig: + """TDF assertion configuration.""" + def __init__( self, id: str, @@ -76,6 +96,7 @@ def __init__( statement: Statement, signing_key: AssertionKey | None = None, ): + """Initialize assertion configuration.""" self.id = id self.type = type self.scope = scope diff --git a/src/otdf_python/asym_crypto.py b/src/otdf_python/asym_crypto.py index b77b609..be5b942 100644 --- a/src/otdf_python/asym_crypto.py +++ b/src/otdf_python/asym_crypto.py @@ -1,6 +1,4 @@ -""" -Asymmetric encryption and decryption utilities for RSA keys in PEM format. -""" +"""Asymmetric encryption and decryption utilities for RSA keys in PEM format.""" import base64 import re @@ -14,8 +12,7 @@ class AsymDecryption: - """ - Provides functionality for asymmetric decryption using an RSA private key. 
+ """Provides functionality for asymmetric decryption using an RSA private key. Supports both PEM string and key object initialization for flexibility. """ @@ -25,8 +22,7 @@ class AsymDecryption: PRIVATE_KEY_FOOTER = "-----END PRIVATE KEY-----" def __init__(self, private_key_pem: str | None = None, private_key_obj=None): - """ - Initialize with either a PEM string or a key object. + """Initialize with either a PEM string or a key object. Args: private_key_pem: Private key in PEM format (with or without headers) @@ -34,6 +30,7 @@ def __init__(self, private_key_pem: str | None = None, private_key_obj=None): Raises: SDKException: If key loading fails + """ if private_key_obj is not None: self.private_key = private_key_obj @@ -65,8 +62,7 @@ def __init__(self, private_key_pem: str | None = None, private_key_obj=None): self.private_key = None def decrypt(self, data: bytes) -> bytes: - """ - Decrypt data using RSA OAEP with SHA-1. + """Decrypt data using RSA OAEP with SHA-1. Args: data: Encrypted bytes to decrypt @@ -76,6 +72,7 @@ def decrypt(self, data: bytes) -> bytes: Raises: SDKException: If decryption fails or key is not set + """ if self.private_key is None: raise SDKException("Failed to decrypt, private key is empty") @@ -93,8 +90,7 @@ def decrypt(self, data: bytes) -> bytes: class AsymEncryption: - """ - Provides functionality for asymmetric encryption using an RSA public key or certificate in PEM format. + """Provides functionality for asymmetric encryption using an RSA public key or certificate in PEM format. Supports PEM public keys, X.509 certificates, and pre-loaded key objects. Also handles base64-encoded keys without PEM headers. @@ -105,8 +101,7 @@ class AsymEncryption: CIPHER_TRANSFORM = "RSA/ECB/OAEPWithSHA-1AndMGF1Padding" def __init__(self, public_key_pem: str | None = None, public_key_obj=None): - """ - Initialize with either a PEM string or a key object. + """Initialize with either a PEM string or a key object. Args: public_key_pem: Public key in PEM format, X.509 certificate, or base64 string @@ -114,6 +109,7 @@ def __init__(self, public_key_pem: str | None = None, public_key_obj=None): Raises: SDKException: If key loading fails or key is not RSA + """ if public_key_obj is not None: self.public_key = public_key_obj @@ -152,8 +148,7 @@ def __init__(self, public_key_pem: str | None = None, public_key_obj=None): raise SDKException("Not an RSA PEM formatted public key") def encrypt(self, data: bytes) -> bytes: - """ - Encrypt data using RSA OAEP with SHA-1. + """Encrypt data using RSA OAEP with SHA-1. Args: data: Plaintext bytes to encrypt @@ -163,6 +158,7 @@ def encrypt(self, data: bytes) -> bytes: Raises: SDKException: If encryption fails or key is not set + """ if self.public_key is None: raise SDKException("Failed to encrypt, public key is empty") @@ -179,14 +175,14 @@ def encrypt(self, data: bytes) -> bytes: raise SDKException(f"Error performing encryption: {e}") from e def public_key_in_pem_format(self) -> str: - """ - Export the public key to PEM format. + """Export the public key to PEM format. 
Returns: Public key as PEM-encoded string Raises: SDKException: If export fails + """ try: pem = self.public_key.public_bytes( diff --git a/src/otdf_python/auth_headers.py b/src/otdf_python/auth_headers.py index 830c144..afb2077 100644 --- a/src/otdf_python/auth_headers.py +++ b/src/otdf_python/auth_headers.py @@ -1,10 +1,11 @@ +"""Authentication header management.""" + from dataclasses import dataclass @dataclass class AuthHeaders: - """ - Represents authentication headers used in token-based authorization. + """Represents authentication headers used in token-based authorization. This class holds authorization and DPoP (Demonstrating Proof of Possession) headers that are used in token-based API requests. """ @@ -13,19 +14,19 @@ class AuthHeaders: dpop_header: str = "" def get_auth_header(self) -> str: - """Returns the authorization header.""" + """Get the authorization header.""" return self.auth_header def get_dpop_header(self) -> str: - """Returns the DPoP header.""" + """Get the DPoP header.""" return self.dpop_header def to_dict(self) -> dict[str, str]: - """ - Convert authentication headers to a dictionary for use with HTTP clients. + """Convert authentication headers to a dictionary for use with HTTP clients. Returns: Dictionary with 'Authorization' header and optionally 'DPoP' header + """ headers = {"Authorization": self.auth_header} if self.dpop_header: diff --git a/src/otdf_python/autoconfigure_utils.py b/src/otdf_python/autoconfigure_utils.py index a1e6c26..ad7a04a 100644 --- a/src/otdf_python/autoconfigure_utils.py +++ b/src/otdf_python/autoconfigure_utils.py @@ -1,3 +1,5 @@ +"""Utilities for automatic SDK configuration.""" + import re import urllib.parse from dataclasses import dataclass @@ -6,6 +8,8 @@ # RuleType constants class RuleType: + """Rule type constants for attribute hierarchy.""" + HIERARCHY = "hierarchy" ALL_OF = "allOf" ANY_OF = "anyOf" @@ -15,6 +19,8 @@ class RuleType: @dataclass(frozen=True) class KeySplitStep: + """Key split step information.""" + kas: str splitID: str @@ -31,11 +37,16 @@ def __hash__(self): class AutoConfigureException(Exception): + """Exception for auto-configuration errors.""" + pass class AttributeNameFQN: + """Fully qualified attribute name.""" + def __init__(self, url: str): + """Initialize attribute name from URL.""" pattern = re.compile(r"^(https?://[\w./-]+)/attr/([^/\s]*)$") matcher = pattern.match(url) if not matcher or not matcher.group(1) or not matcher.group(2): @@ -81,7 +92,10 @@ def name(self): class AttributeValueFQN: + """Fully qualified attribute value.""" + def __init__(self, url: str): + """Initialize attribute value from URL.""" pattern = re.compile(r"^(https?://[\w./-]+)/attr/(\S*)/value/(\S*)$") matcher = pattern.match(url) if ( diff --git a/src/otdf_python/cli.py b/src/otdf_python/cli.py index 3c33e3c..8037a35 100644 --- a/src/otdf_python/cli.py +++ b/src/otdf_python/cli.py @@ -1,6 +1,5 @@ #!/usr/bin/env python3 -""" -OpenTDF Python CLI +"""OpenTDF Python CLI. A command-line interface for encrypting and decrypting files using OpenTDF. Provides encrypt, decrypt, and inspect commands similar to the otdfctl CLI. 
@@ -36,6 +35,7 @@ class CLIError(Exception): """Custom exception for CLI errors.""" def __init__(self, level: str, message: str, cause: Exception | None = None): + """Initialize CLI error.""" self.level = level self.message = message self.cause = cause @@ -525,7 +525,7 @@ def create_parser() -> argparse.ArgumentParser: def main(): - """Main CLI entry point.""" + """Execute the CLI entry point.""" parser = create_parser() args = parser.parse_args() diff --git a/src/otdf_python/collection_store.py b/src/otdf_python/collection_store.py index 8716ff7..215dbd8 100644 --- a/src/otdf_python/collection_store.py +++ b/src/otdf_python/collection_store.py @@ -1,12 +1,19 @@ +"""Collection store interface for managing collections.""" + from collections import OrderedDict class CollectionKey: + """Collection key wrapper for store operations.""" + def __init__(self, key: bytes | None): + """Initialize collection key.""" self.key = key class CollectionStore: + """Abstract collection store interface for key management.""" + NO_PRIVATE_KEY = CollectionKey(None) def store(self, header, key: CollectionKey): @@ -17,7 +24,10 @@ def get_key(self, header) -> CollectionKey: class NoOpCollectionStore(CollectionStore): + """No-op collection store that discards all keys.""" + def store(self, header, key: CollectionKey): + """Discard key operation (no-op).""" pass def get_key(self, header) -> CollectionKey: @@ -25,9 +35,12 @@ def get_key(self, header) -> CollectionKey: class CollectionStoreImpl(OrderedDict, CollectionStore): + """Collection store implementation with ordered dictionary.""" + MAX_SIZE_STORE = 500 def __init__(self): + """Initialize collection store.""" super().__init__() def store(self, header, key: CollectionKey): diff --git a/src/otdf_python/collection_store_impl.py b/src/otdf_python/collection_store_impl.py index 9b25042..e63d6a4 100644 --- a/src/otdf_python/collection_store_impl.py +++ b/src/otdf_python/collection_store_impl.py @@ -1,3 +1,5 @@ +"""Collection store implementation.""" + from collections import OrderedDict from threading import RLock @@ -5,7 +7,10 @@ class CollectionStoreImpl(OrderedDict): + """Thread-safe collection store for caching TDF keys.""" + def __init__(self): + """Initialize collection store.""" super().__init__() self._lock = RLock() diff --git a/src/otdf_python/config.py b/src/otdf_python/config.py index 646acec..f8c7252 100644 --- a/src/otdf_python/config.py +++ b/src/otdf_python/config.py @@ -1,3 +1,5 @@ +"""Configuration classes for TDF and NanoTDF operations.""" + from dataclasses import dataclass, field from enum import Enum from typing import Any @@ -5,17 +7,23 @@ class TDFFormat(Enum): + """TDF format enumeration.""" + JSONFormat = "JSONFormat" XMLFormat = "XMLFormat" class IntegrityAlgorithm(Enum): + """Integrity algorithm enumeration.""" + HS256 = "HS256" GMAC = "GMAC" @dataclass class KASInfo: + """Key Access Service information.""" + url: str public_key: str | None = None kid: str | None = None @@ -28,6 +36,8 @@ def __str__(self): @dataclass class TDFConfig: + """TDF encryption configuration.""" + autoconfigure: bool = True default_segment_size: int = 2 * 1024 * 1024 enable_encryption: bool = True @@ -49,6 +59,8 @@ class TDFConfig: @dataclass class NanoTDFConfig: + """NanoTDF encryption configuration.""" + ecc_mode: str | None = None cipher: str | None = None config: str | None = None @@ -60,6 +72,7 @@ class NanoTDFConfig: # Utility function to normalize KAS URLs (Python equivalent) def get_kas_address(kas_url: str) -> str: + """Normalize KAS URL by adding 
https:// if no scheme present.""" if "://" not in kas_url: kas_url = "https://" + kas_url parsed = urlparse(kas_url) diff --git a/src/otdf_python/connect_client.py b/src/otdf_python/connect_client.py index e69de29..a77b5f8 100644 --- a/src/otdf_python/connect_client.py +++ b/src/otdf_python/connect_client.py @@ -0,0 +1 @@ +"""Connect RPC client for KAS operations.""" diff --git a/src/otdf_python/constants.py b/src/otdf_python/constants.py index 048e914..cce26cb 100644 --- a/src/otdf_python/constants.py +++ b/src/otdf_python/constants.py @@ -1 +1,3 @@ +"""Application constants and default values.""" + MAGIC_NUMBER_AND_VERSION = bytes([0x4C, 0x31, 0x4C]) diff --git a/src/otdf_python/crypto_utils.py b/src/otdf_python/crypto_utils.py index b32a5e9..ab29c5a 100644 --- a/src/otdf_python/crypto_utils.py +++ b/src/otdf_python/crypto_utils.py @@ -1,3 +1,5 @@ +"""Cryptographic utility functions.""" + import hashlib import hmac @@ -7,6 +9,8 @@ class CryptoUtils: + """Cryptographic utility functions and helpers.""" + KEYPAIR_SIZE = 2048 @staticmethod diff --git a/src/otdf_python/dpop.py b/src/otdf_python/dpop.py index c442a5e..7102fcf 100644 --- a/src/otdf_python/dpop.py +++ b/src/otdf_python/dpop.py @@ -1,6 +1,4 @@ -""" -DPoP (Demonstration of Proof-of-Possession) token generation utilities. -""" +"""DPoP (Demonstration of Proof-of-Possession) token generation utilities.""" import base64 import hashlib @@ -18,8 +16,7 @@ def create_dpop_token( method: str = "POST", access_token: str | None = None, ) -> str: - """ - Create a DPoP (Demonstration of Proof-of-Possession) token. + """Create a DPoP (Demonstration of Proof-of-Possession) token. Args: private_key_pem: RSA private key in PEM format for signing @@ -30,6 +27,7 @@ def create_dpop_token( Returns: DPoP token as a string + """ # Parse the RSA public key to extract modulus and exponent public_key_obj = CryptoUtils.get_rsa_public_key_from_pem(public_key_pem) diff --git a/src/otdf_python/ecc_constants.py b/src/otdf_python/ecc_constants.py index 1d8d409..37c5f49 100644 --- a/src/otdf_python/ecc_constants.py +++ b/src/otdf_python/ecc_constants.py @@ -1,5 +1,4 @@ -""" -Elliptic Curve Constants for NanoTDF. +"""Elliptic Curve Constants for NanoTDF. This module defines shared constants for elliptic curve operations used across the SDK, particularly for NanoTDF encryption/decryption. @@ -14,8 +13,7 @@ class ECCConstants: - """ - Centralized constants for elliptic curve cryptography operations. + """Centralized constants for elliptic curve cryptography operations. This class provides mappings between curve names, curve type integers, cryptography curve objects, and compressed public key sizes. @@ -67,8 +65,7 @@ class ECCConstants: @classmethod def get_curve_name(cls, curve_type: int) -> str: - """ - Get curve name from curve type integer. + """Get curve name from curve type integer. Args: curve_type: Curve type (0=secp256r1, 1=secp384r1, 2=secp521r1, 3=secp256k1) @@ -78,6 +75,7 @@ def get_curve_name(cls, curve_type: int) -> str: Raises: ValueError: If curve_type is not supported + """ name = cls.CURVE_TYPE_TO_NAME.get(curve_type) if name is None: @@ -89,8 +87,7 @@ def get_curve_name(cls, curve_type: int) -> str: @classmethod def get_curve_type(cls, curve_name: str) -> int: - """ - Get curve type integer from curve name. + """Get curve type integer from curve name. 
Args: curve_name: Curve name (e.g., "secp256r1") @@ -100,6 +97,7 @@ def get_curve_type(cls, curve_name: str) -> int: Raises: ValueError: If curve_name is not supported + """ curve_type = cls.CURVE_NAME_TO_TYPE.get(curve_name.lower()) if curve_type is None: @@ -111,8 +109,7 @@ def get_curve_type(cls, curve_name: str) -> int: @classmethod def get_compressed_key_size_by_type(cls, curve_type: int) -> int: - """ - Get compressed public key size from curve type integer. + """Get compressed public key size from curve type integer. Args: curve_type: Curve type (0=secp256r1, 1=secp384r1, 2=secp521r1, 3=secp256k1) @@ -122,6 +119,7 @@ def get_compressed_key_size_by_type(cls, curve_type: int) -> int: Raises: ValueError: If curve_type is not supported + """ size = cls.COMPRESSED_KEY_SIZE_BY_TYPE.get(curve_type) if size is None: @@ -133,8 +131,7 @@ def get_compressed_key_size_by_type(cls, curve_type: int) -> int: @classmethod def get_compressed_key_size_by_name(cls, curve_name: str) -> int: - """ - Get compressed public key size from curve name. + """Get compressed public key size from curve name. Args: curve_name: Curve name (e.g., "secp256r1") @@ -144,6 +141,7 @@ def get_compressed_key_size_by_name(cls, curve_name: str) -> int: Raises: ValueError: If curve_name is not supported + """ size = cls.COMPRESSED_KEY_SIZE_BY_NAME.get(curve_name.lower()) if size is None: @@ -155,8 +153,7 @@ def get_compressed_key_size_by_name(cls, curve_name: str) -> int: @classmethod def get_curve_object(cls, curve_name: str) -> ec.EllipticCurve: - """ - Get cryptography library curve object from curve name. + """Get cryptography library curve object from curve name. Args: curve_name: Curve name (e.g., "secp256r1") @@ -166,6 +163,7 @@ def get_curve_object(cls, curve_name: str) -> ec.EllipticCurve: Raises: ValueError: If curve_name is not supported + """ curve = cls.CURVE_OBJECTS.get(curve_name.lower()) if curve is None: diff --git a/src/otdf_python/ecc_mode.py b/src/otdf_python/ecc_mode.py index 3966e9d..1edca4b 100644 --- a/src/otdf_python/ecc_mode.py +++ b/src/otdf_python/ecc_mode.py @@ -1,9 +1,10 @@ +"""Elliptic Curve Cryptography mode enumeration.""" + from otdf_python.ecc_constants import ECCConstants class ECCMode: - """ - ECC (Elliptic Curve Cryptography) mode configuration for NanoTDF. + """ECC (Elliptic Curve Cryptography) mode configuration for NanoTDF. This class encapsulates the curve type and policy binding mode (GMAC vs ECDSA) that are encoded in the NanoTDF header. 
It delegates to ECCConstants for @@ -11,6 +12,7 @@ class ECCMode: """ def __init__(self, curve_mode: int = 0, use_ecdsa_binding: bool = False): + """Initialize ECC mode.""" self.curve_mode = curve_mode self.use_ecdsa_binding = use_ecdsa_binding @@ -34,6 +36,7 @@ def get_curve_name(self) -> str: Raises: ValueError: If curve_mode is not supported + """ # Delegate to ECCConstants for the authoritative mapping return ECCConstants.get_curve_name(self.curve_mode) @@ -50,6 +53,7 @@ def get_ec_compressed_pubkey_size(curve_type: int) -> int: Raises: ValueError: If curve_type is not supported + """ # Delegate to ECCConstants for the authoritative mapping return ECCConstants.get_compressed_key_size_by_type(curve_type) @@ -71,6 +75,7 @@ def from_string(curve_str: str) -> "ECCMode": Raises: ValueError: If curve_str is not a supported curve or binding type + """ # Handle policy binding types (always use secp256r1 as default curve) if curve_str.lower() == "gmac": diff --git a/src/otdf_python/ecdh.py b/src/otdf_python/ecdh.py index 63da6ed..eb093c1 100644 --- a/src/otdf_python/ecdh.py +++ b/src/otdf_python/ecdh.py @@ -1,5 +1,4 @@ -""" -ECDH (Elliptic Curve Diffie-Hellman) key exchange for NanoTDF. +"""ECDH (Elliptic Curve Diffie-Hellman) key exchange for NanoTDF. This module implements the ECDH key exchange protocol with HKDF key derivation as specified in the NanoTDF spec. It supports the following curves: @@ -52,8 +51,7 @@ class InvalidKeyError(ECDHError): def get_curve(curve_name: str) -> ec.EllipticCurve: - """ - Get the cryptography curve object for a given curve name. + """Get the cryptography curve object for a given curve name. Args: curve_name: Name of the curve (e.g., "secp256r1") @@ -63,6 +61,7 @@ def get_curve(curve_name: str) -> ec.EllipticCurve: Raises: UnsupportedCurveError: If the curve is not supported + """ try: # Delegate to ECCConstants for the authoritative mapping @@ -72,8 +71,7 @@ def get_curve(curve_name: str) -> ec.EllipticCurve: def get_compressed_key_size(curve_name: str) -> int: - """ - Get the size of a compressed public key for a given curve. + """Get the size of a compressed public key for a given curve. Args: curve_name: Name of the curve (e.g., "secp256r1") @@ -83,6 +81,7 @@ def get_compressed_key_size(curve_name: str) -> int: Raises: UnsupportedCurveError: If the curve is not supported + """ try: # Delegate to ECCConstants for the authoritative mapping @@ -94,8 +93,7 @@ def get_compressed_key_size(curve_name: str) -> int: def generate_ephemeral_keypair( curve_name: str, ) -> tuple[ec.EllipticCurvePrivateKey, ec.EllipticCurvePublicKey]: - """ - Generate an ephemeral keypair for ECDH. + """Generate an ephemeral keypair for ECDH. Args: curve_name: Name of the curve (e.g., "secp256r1") @@ -105,6 +103,7 @@ def generate_ephemeral_keypair( Raises: UnsupportedCurveError: If the curve is not supported + """ curve = get_curve(curve_name) private_key = ec.generate_private_key(curve, default_backend()) @@ -113,14 +112,14 @@ def generate_ephemeral_keypair( def compress_public_key(public_key: ec.EllipticCurvePublicKey) -> bytes: - """ - Compress an EC public key to compressed point format. + """Compress an EC public key to compressed point format. 
Args: public_key: The EC public key to compress Returns: bytes: Compressed public key (33-67 bytes depending on curve) + """ return public_key.public_bytes( encoding=Encoding.X962, format=PublicFormat.CompressedPoint @@ -130,8 +129,7 @@ def compress_public_key(public_key: ec.EllipticCurvePublicKey) -> bytes: def decompress_public_key( compressed_key: bytes, curve_name: str ) -> ec.EllipticCurvePublicKey: - """ - Decompress a public key from compressed point format. + """Decompress a public key from compressed point format. Args: compressed_key: The compressed public key bytes @@ -143,6 +141,7 @@ def decompress_public_key( Raises: InvalidKeyError: If the key cannot be decompressed UnsupportedCurveError: If the curve is not supported + """ try: curve = get_curve(curve_name) @@ -162,8 +161,7 @@ def decompress_public_key( def derive_shared_secret( private_key: ec.EllipticCurvePrivateKey, public_key: ec.EllipticCurvePublicKey ) -> bytes: - """ - Derive a shared secret using ECDH. + """Derive a shared secret using ECDH. Args: private_key: The private key (can be ephemeral or recipient's key) @@ -174,6 +172,7 @@ def derive_shared_secret( Raises: ECDHError: If ECDH fails + """ try: shared_secret = private_key.exchange(ec.ECDH(), public_key) @@ -188,8 +187,7 @@ def derive_key_from_shared_secret( salt: bytes | None = None, info: bytes = b"", ) -> bytes: - """ - Derive a symmetric encryption key from the ECDH shared secret using HKDF. + """Derive a symmetric encryption key from the ECDH shared secret using HKDF. Args: shared_secret: The raw ECDH shared secret @@ -202,6 +200,7 @@ def derive_key_from_shared_secret( Raises: ECDHError: If key derivation fails + """ if salt is None: salt = NANOTDF_HKDF_SALT @@ -222,8 +221,7 @@ def derive_key_from_shared_secret( def encrypt_key_with_ecdh( recipient_public_key_pem: str, curve_name: str = "secp256r1" ) -> tuple[bytes, bytes]: - """ - High-level function: Generate ephemeral keypair and derive encryption key. + """High-level function: Generate ephemeral keypair and derive encryption key. This is used during NanoTDF encryption to derive the key that will be used to encrypt the payload. The ephemeral public key must be stored in the @@ -242,6 +240,7 @@ def encrypt_key_with_ecdh( ECDHError: If key derivation fails InvalidKeyError: If recipient's public key is invalid UnsupportedCurveError: If the curve is not supported + """ # Load recipient's public key try: @@ -273,8 +272,7 @@ def decrypt_key_with_ecdh( compressed_ephemeral_public_key: bytes, curve_name: str = "secp256r1", ) -> bytes: - """ - High-level function: Derive decryption key from ephemeral public key and recipient's private key. + """High-level function: Derive decryption key from ephemeral public key and recipient's private key. This is used during NanoTDF decryption to derive the same key that was used to encrypt the payload. 
The ephemeral public key is extracted from the @@ -292,6 +290,7 @@ def decrypt_key_with_ecdh( ECDHError: If key derivation fails InvalidKeyError: If keys are invalid UnsupportedCurveError: If the curve is not supported + """ # Load recipient's private key try: diff --git a/src/otdf_python/eckeypair.py b/src/otdf_python/eckeypair.py index 3dee0aa..5cf1908 100644 --- a/src/otdf_python/eckeypair.py +++ b/src/otdf_python/eckeypair.py @@ -1,3 +1,5 @@ +"""Elliptic Curve key pair management.""" + from cryptography.exceptions import InvalidSignature from cryptography.hazmat.backends import default_backend from cryptography.hazmat.primitives import hashes, serialization @@ -12,7 +14,10 @@ class ECKeyPair: + """Elliptic Curve key pair for cryptographic operations.""" + def __init__(self, curve=None): + """Initialize EC key pair.""" if curve is None: curve = ec.SECP256R1() self.private_key = ec.generate_private_key(curve, default_backend()) diff --git a/src/otdf_python/header.py b/src/otdf_python/header.py index 42f82ac..2379b44 100644 --- a/src/otdf_python/header.py +++ b/src/otdf_python/header.py @@ -1,3 +1,5 @@ +"""TDF header parsing and serialization.""" + from otdf_python.constants import MAGIC_NUMBER_AND_VERSION from otdf_python.ecc_mode import ECCMode from otdf_python.policy_info import PolicyInfo @@ -6,10 +8,13 @@ class Header: + """TDF header with encryption and policy information.""" + # Size of GMAC (Galois Message Authentication Code) for policy binding GMAC_SIZE = 8 def __init__(self): + """Initialize TDF header.""" self.kas_locator: ResourceLocator | None = None self.ecc_mode: ECCMode | None = None self.payload_config: SymmetricAndPayloadConfig | None = None diff --git a/src/otdf_python/invalid_zip_exception.py b/src/otdf_python/invalid_zip_exception.py index 7ae67ad..cc18b63 100644 --- a/src/otdf_python/invalid_zip_exception.py +++ b/src/otdf_python/invalid_zip_exception.py @@ -1,8 +1,12 @@ +"""Exception for invalid ZIP file errors.""" + + class InvalidZipException(Exception): - """ - Raised when a ZIP file is invalid or corrupted. + """Raised when a ZIP file is invalid or corrupted. + Based on Java implementation. """ def __init__(self, message: str): + """Initialize exception.""" super().__init__(message) diff --git a/src/otdf_python/kas_client.py b/src/otdf_python/kas_client.py index 1b4a669..f0400ba 100644 --- a/src/otdf_python/kas_client.py +++ b/src/otdf_python/kas_client.py @@ -1,6 +1,4 @@ -""" -KASClient: Handles communication with the Key Access Service (KAS). -""" +"""KASClient: Handles communication with the Key Access Service (KAS).""" import base64 import hashlib @@ -22,6 +20,8 @@ @dataclass class KeyAccess: + """Key access response from KAS.""" + url: str wrapped_key: str ephemeral_public_key: str | None = None @@ -29,6 +29,8 @@ class KeyAccess: class KASClient: + """Client for communicating with the Key Access Service (KAS).""" + def __init__( self, kas_url=None, @@ -37,6 +39,7 @@ def __init__( use_plaintext=False, verify_ssl=True, ): + """Initialize KAS client.""" self.kas_url = kas_url self.token_source = token_source self.cache = cache or KASKeyCache() @@ -63,14 +66,14 @@ def __init__( ) def _normalize_kas_url(self, url: str) -> str: - """ - Normalize KAS URLs based on client security settings. + """Normalize KAS URLs based on client security settings. 
Args: url: The KAS URL to normalize Returns: Normalized URL with appropriate protocol and port + """ from urllib.parse import urlparse @@ -141,14 +144,14 @@ def _handle_existing_scheme(self, parsed) -> str: raise SDKException(f"error creating KAS address: {e}") from e def _get_wrapped_key_base64(self, key_access): - """ - Extract and normalize the wrapped key to base64-encoded string. + """Extract and normalize the wrapped key to base64-encoded string. Args: key_access: KeyAccess object Returns: Base64-encoded wrapped key string + """ wrapped_key = getattr(key_access, "wrappedKey", None) or getattr( key_access, "wrapped_key", None @@ -166,14 +169,14 @@ def _get_wrapped_key_base64(self, key_access): return wrapped_key def _build_key_access_dict(self, key_access): - """ - Build key access dictionary from KeyAccess object, handling both old and new field names. + """Build key access dictionary from KeyAccess object, handling both old and new field names. Args: key_access: KeyAccess object Returns: Dictionary with key access information + """ wrapped_key = self._get_wrapped_key_base64(key_access) @@ -197,12 +200,12 @@ def _build_key_access_dict(self, key_access): return key_access_dict def _add_optional_fields(self, key_access_dict, key_access): - """ - Add optional fields to key access dictionary. + """Add optional fields to key access dictionary. Args: key_access_dict: Dictionary to add fields to key_access: KeyAccess object to extract fields from + """ # Policy binding policy_binding = getattr(key_access, "policyBinding", None) or getattr( @@ -244,14 +247,14 @@ def _add_optional_fields(self, key_access_dict, key_access): key_access_dict["header"] = base64.b64encode(header).decode("utf-8") def _get_algorithm_from_session_key_type(self, session_key_type): - """ - Convert session key type to algorithm string for KAS. + """Convert session key type to algorithm string for KAS. Args: session_key_type: Session key type (EC_KEY_TYPE or RSA_KEY_TYPE) Returns: Algorithm string or None + """ if session_key_type == EC_KEY_TYPE: return "ec:secp256r1" # Default EC curve for NanoTDF @@ -262,8 +265,7 @@ def _get_algorithm_from_session_key_type(self, session_key_type): def _build_rewrap_request( self, policy_json, client_public_key, key_access_dict, algorithm, has_header ): - """ - Build the unsigned rewrap request structure. + """Build the unsigned rewrap request structure. Args: policy_json: Policy JSON string @@ -274,6 +276,7 @@ def _build_rewrap_request( Returns: Dictionary with unsigned rewrap request + """ import json @@ -316,8 +319,7 @@ def _build_rewrap_request( def _create_signed_request_jwt( self, policy_json, client_public_key, key_access, session_key_type=None ): - """ - Create a signed JWT for the rewrap request. + """Create a signed JWT for the rewrap request. The JWT is signed with the DPoP private key. Args: @@ -325,6 +327,7 @@ def _create_signed_request_jwt( client_public_key: Client public key PEM string key_access: KeyAccess object session_key_type: Optional session key type (RSA_KEY_TYPE or EC_KEY_TYPE) + """ # Build key access dictionary handling both old and new field names key_access_dict = self._build_key_access_dict(key_access) @@ -354,8 +357,7 @@ def _create_signed_request_jwt( return jwt.encode(payload, self._dpop_private_key_pem, algorithm="RS256") def _create_connect_rpc_signed_token(self, key_access, policy_json): - """ - Create a signed token specifically for Connect RPC requests. + """Create a signed token specifically for Connect RPC requests. 
For now, this delegates to the existing JWT creation method. """ return self._create_signed_request_jwt( @@ -363,8 +365,7 @@ def _create_connect_rpc_signed_token(self, key_access, policy_json): ) def _create_dpop_proof(self, method, url, access_token=None): - """ - Create a DPoP proof JWT as per RFC 9449. + """Create a DPoP proof JWT as per RFC 9449. Args: method: HTTP method (e.g., "POST") @@ -373,6 +374,7 @@ def _create_dpop_proof(self, method, url, access_token=None): Returns: DPoP proof JWT string + """ now = int(time.time()) @@ -424,8 +426,7 @@ def _create_dpop_proof(self, method, url, access_token=None): ) def get_public_key(self, kas_info): - """ - Get KAS public key using Connect RPC. + """Get KAS public key using Connect RPC. Checks cache first if available. """ try: @@ -448,10 +449,7 @@ def get_public_key(self, kas_info): raise def _get_public_key_with_connect_rpc(self, kas_info): - """ - Get KAS public key using Connect RPC. - """ - + """Get KAS public key using Connect RPC.""" # Get access token for authentication if token source is available access_token = None if self.token_source: @@ -486,14 +484,14 @@ def _get_public_key_with_connect_rpc(self, kas_info): raise SDKException(f"Connect RPC public key request failed: {e}") from e def _normalize_session_key_type(self, session_key_type): - """ - Normalize session key type to the appropriate enum value. + """Normalize session key type to the appropriate enum value. Args: session_key_type: Type of the session key (KeyType enum or string "RSA"/"EC") Returns: Normalized key type enum + """ if isinstance(session_key_type, str): if session_key_type.upper() == "RSA": @@ -511,14 +509,14 @@ def _normalize_session_key_type(self, session_key_type): return session_key_type def _prepare_ec_keypair(self, session_key_type): - """ - Prepare EC key pair for unwrapping. + """Prepare EC key pair for unwrapping. Args: session_key_type: EC key type with curve information Returns: ECKeyPair instance and client public key + """ from .eckeypair import ECKeyPair @@ -528,12 +526,12 @@ def _prepare_ec_keypair(self, session_key_type): return ec_key_pair, client_public_key def _prepare_rsa_keypair(self): - """ - Prepare RSA key pair for unwrapping, reusing if possible. + """Prepare RSA key pair for unwrapping, reusing if possible. Uses separate ephemeral keys for encryption (not DPoP keys). Returns: Client public key PEM for the ephemeral encryption key + """ if self.decryptor is None: # Generate ephemeral keys for encryption (separate from DPoP keys) @@ -543,8 +541,7 @@ def _prepare_rsa_keypair(self): return self.client_public_key def _unwrap_with_ec(self, wrapped_key, ec_key_pair, response_data): - """ - Unwrap a key using EC cryptography. + """Unwrap a key using EC cryptography. Args: wrapped_key: The wrapped key to decrypt @@ -553,6 +550,7 @@ def _unwrap_with_ec(self, wrapped_key, ec_key_pair, response_data): Returns: Unwrapped key as bytes + """ if ec_key_pair is None: raise SDKException( @@ -581,9 +579,7 @@ def _unwrap_with_ec(self, wrapped_key, ec_key_pair, response_data): return gcm.decrypt(wrapped_key) def _ensure_client_keypair(self, session_key_type): - """ - Ensure client keypair is generated and stored. 
- """ + """Ensure client keypair is generated and stored.""" if session_key_type == RSA_KEY_TYPE: if self.decryptor is None: private_key, public_key = CryptoUtils.generate_rsa_keypair() @@ -600,9 +596,7 @@ def _ensure_client_keypair(self, session_key_type): self.client_public_key = CryptoUtils.get_rsa_public_key_pem(public_key) def _parse_and_decrypt_response(self, response): - """ - Parse JSON response and decrypt the wrapped key. - """ + """Parse JSON response and decrypt the wrapped key.""" try: response_data = response.json() except Exception as e: @@ -621,8 +615,7 @@ def _parse_and_decrypt_response(self, response): return self.decryptor.decrypt(encrypted_key) def unwrap(self, key_access, policy_json, session_key_type=None) -> bytes: - """ - Unwrap a key using Connect RPC. + """Unwrap a key using Connect RPC. Args: key_access: Key access information @@ -631,6 +624,7 @@ def unwrap(self, key_access, policy_json, session_key_type=None) -> bytes: Returns: Unwrapped key bytes + """ # Default to RSA if not specified if session_key_type is None: @@ -655,15 +649,14 @@ def unwrap(self, key_access, policy_json, session_key_type=None) -> bytes: def _unwrap_with_connect_rpc( self, key_access, signed_token, session_key_type=None ) -> bytes: - """ - Connect RPC method for unwrapping keys. + """Connect RPC method for unwrapping keys. Args: key_access: KeyAccess object signed_token: Signed JWT token session_key_type: Optional session key type (RSA_KEY_TYPE or EC_KEY_TYPE) - """ + """ # Get access token for authentication if token source is available access_token = None if self.token_source: @@ -705,5 +698,5 @@ def _unwrap_with_connect_rpc( raise SDKException(f"Connect RPC rewrap failed: {e}") from e def get_key_cache(self) -> KASKeyCache: - """Returns the KAS key cache used for storing and retrieving encryption keys.""" + """Return the KAS key cache used for storing and retrieving encryption keys.""" return self.cache diff --git a/src/otdf_python/kas_connect_rpc_client.py b/src/otdf_python/kas_connect_rpc_client.py index 9589f92..50397d0 100644 --- a/src/otdf_python/kas_connect_rpc_client.py +++ b/src/otdf_python/kas_connect_rpc_client.py @@ -1,5 +1,4 @@ -""" -KASConnectRPCClient: Handles Connect RPC communication with the Key Access Service (KAS). +"""KASConnectRPCClient: Handles Connect RPC communication with the Key Access Service (KAS). This class encapsulates all interactions with otdf_python_proto. """ @@ -15,27 +14,25 @@ class KASConnectRPCClient: - """ - Handles Connect RPC communication with KAS service using otdf_python_proto. - """ + """Handles Connect RPC communication with KAS service using otdf_python_proto.""" def __init__(self, use_plaintext=False, verify_ssl=True): - """ - Initialize the Connect RPC client. + """Initialize the Connect RPC client. Args: use_plaintext: Whether to use plaintext (HTTP) connections verify_ssl: Whether to verify SSL certificates + """ self.use_plaintext = use_plaintext self.verify_ssl = verify_ssl def _create_http_client(self): - """ - Create HTTP client with SSL verification configuration. + """Create HTTP client with SSL verification configuration. Returns: urllib3.PoolManager configured for SSL verification settings + """ if self.verify_ssl: logging.info("Using SSL verification enabled HTTP client") @@ -46,14 +43,14 @@ def _create_http_client(self): return urllib3.PoolManager(cert_reqs="CERT_NONE") def _prepare_connect_rpc_url(self, kas_url): - """ - Prepare the base URL for Connect RPC client. + """Prepare the base URL for Connect RPC client. 
Args: kas_url: The normalized KAS URL Returns: Base URL for Connect RPC client (without /kas suffix) + """ connect_rpc_base_url = kas_url # Remove /kas suffix, if present @@ -61,14 +58,14 @@ def _prepare_connect_rpc_url(self, kas_url): return connect_rpc_base_url def _prepare_auth_headers(self, access_token): - """ - Prepare authentication headers if access token is available. + """Prepare authentication headers if access token is available. Args: access_token: Bearer token for authentication Returns: Dictionary with authentication headers or None + """ if access_token: auth_headers = AuthHeaders( @@ -79,8 +76,7 @@ def _prepare_auth_headers(self, access_token): return None def get_public_key(self, normalized_kas_url, kas_info, access_token=None): - """ - Get KAS public key using Connect RPC. + """Get KAS public key using Connect RPC. Args: normalized_kas_url: The normalized KAS URL @@ -89,6 +85,7 @@ def get_public_key(self, normalized_kas_url, kas_info, access_token=None): Returns: Updated kas_info with public_key and kid + """ logging.info( f"KAS Connect RPC client settings for public key retrieval: " @@ -143,8 +140,7 @@ def get_public_key(self, normalized_kas_url, kas_info, access_token=None): def unwrap_key( self, normalized_kas_url, key_access, signed_token, access_token=None ): - """ - Unwrap a key using Connect RPC. + """Unwrap a key using Connect RPC. Args: normalized_kas_url: The normalized KAS URL @@ -154,6 +150,7 @@ def unwrap_key( Returns: Unwrapped key bytes from the response + """ logging.info( f"KAS Connect RPC client settings for unwrap: " diff --git a/src/otdf_python/kas_info.py b/src/otdf_python/kas_info.py index 189ec73..336359a 100644 --- a/src/otdf_python/kas_info.py +++ b/src/otdf_python/kas_info.py @@ -1,10 +1,11 @@ +"""Key Access Service information and configuration.""" + from dataclasses import dataclass @dataclass class KASInfo: - """ - Configuration for Key Access Server (KAS) information. + """Configuration for Key Access Server (KAS) information. This class stores details about a Key Access Server including its URL, public key, key ID, default status, and cryptographic algorithm. """ @@ -16,7 +17,7 @@ class KASInfo: algorithm: str | None = None def clone(self): - """Creates a copy of this KASInfo object.""" + """Create a copy of this KASInfo object.""" from copy import copy return copy(self) diff --git a/src/otdf_python/kas_key_cache.py b/src/otdf_python/kas_key_cache.py index 38e1f22..e1a2a6a 100644 --- a/src/otdf_python/kas_key_cache.py +++ b/src/otdf_python/kas_key_cache.py @@ -1,19 +1,19 @@ -""" -KASKeyCache: In-memory cache for KAS (Key Access Service) public keys and info. -""" +"""KASKeyCache: In-memory cache for KAS (Key Access Service) public keys and info.""" import threading from typing import Any class KASKeyCache: + """In-memory cache for KAS public keys and information.""" + def __init__(self): + """Initialize KAS key cache.""" self._cache = {} self._lock = threading.Lock() def get(self, url: str, algorithm: str | None = None) -> Any | None: - """ - Gets a KASInfo object from the cache based on URL and algorithm. + """Get a KASInfo object from cache based on URL and algorithm. Args: url: The URL of the KAS @@ -21,17 +21,18 @@ def get(self, url: str, algorithm: str | None = None) -> Any | None: Returns: The cached KASInfo object, or None if not found + """ cache_key = self._make_key(url, algorithm) with self._lock: return self._cache.get(cache_key) def store(self, kas_info) -> None: - """ - Stores a KASInfo object in the cache. 
+ """Store a KASInfo object in cache. Args: kas_info: The KASInfo object to store + """ cache_key = self._make_key(kas_info.url, getattr(kas_info, "algorithm", None)) with self._lock: @@ -43,10 +44,10 @@ def set(self, key, value): self._cache[key] = value def clear(self): - """Clears the cache""" + """Clear the cache.""" with self._lock: self._cache.clear() def _make_key(self, url: str, algorithm: str | None = None) -> str: - """Creates a cache key from URL and algorithm""" + """Create a cache key from URL and algorithm.""" return f"{url}:{algorithm or ''}" diff --git a/src/otdf_python/key_type.py b/src/otdf_python/key_type.py index cb14f53..4dd688e 100644 --- a/src/otdf_python/key_type.py +++ b/src/otdf_python/key_type.py @@ -1,7 +1,11 @@ +"""Key type constants for RSA and EC encryption.""" + from enum import Enum class KeyType(Enum): + """Key type enumeration for encryption algorithms.""" + RSA2048Key = "rsa:2048" EC256Key = "ec:secp256r1" EC384Key = "ec:secp384r1" diff --git a/src/otdf_python/key_type_constants.py b/src/otdf_python/key_type_constants.py index a99da44..185cabb 100644 --- a/src/otdf_python/key_type_constants.py +++ b/src/otdf_python/key_type_constants.py @@ -1,5 +1,4 @@ -""" -Constants for session key types used in the KAS client. +"""Constants for session key types used in the KAS client. This matches the Java SDK's KeyType enum pattern. """ @@ -7,9 +6,7 @@ class KeyType(Enum): - """ - Enum for key types used in the KAS client. - """ + """Enum for key types used in the KAS client.""" RSA2048 = auto() EC_P256 = auto() @@ -18,16 +15,12 @@ class KeyType(Enum): @property def is_ec(self): - """ - Returns True if this key type is an EC key, False otherwise. - """ + """Returns True if this key type is an EC key, False otherwise.""" return self in [KeyType.EC_P256, KeyType.EC_P384, KeyType.EC_P521] @property def curve_name(self): - """ - Returns the curve name for EC keys. 
- """ + """Returns the curve name for EC keys.""" if self == KeyType.EC_P256: return "P-256" elif self == KeyType.EC_P384: diff --git a/src/otdf_python/manifest.py b/src/otdf_python/manifest.py index 1ebbae3..0cf7d4f 100644 --- a/src/otdf_python/manifest.py +++ b/src/otdf_python/manifest.py @@ -1,3 +1,5 @@ +"""TDF manifest representation and serialization.""" + import json from dataclasses import asdict, dataclass, field from typing import Any @@ -5,6 +7,8 @@ @dataclass class ManifestSegment: + """Encrypted segment information in TDF manifest.""" + hash: str segmentSize: int encryptedSegmentSize: int @@ -12,12 +16,16 @@ class ManifestSegment: @dataclass class ManifestRootSignature: + """Root signature for manifest integrity.""" + alg: str sig: str @dataclass class ManifestIntegrityInformation: + """Manifest integrity information with signatures and hashes.""" + rootSignature: ManifestRootSignature segmentHashAlg: str segmentSizeDefault: int @@ -27,12 +35,16 @@ class ManifestIntegrityInformation: @dataclass class ManifestPolicyBinding: + """Policy binding with algorithm and hash.""" + alg: str hash: str @dataclass class ManifestKeyAccess: + """Key access information in manifest.""" + type: str url: str protocol: str @@ -47,6 +59,8 @@ class ManifestKeyAccess: @dataclass class ManifestMethod: + """Encryption method information in manifest.""" + algorithm: str iv: str isStreamable: bool | None = None @@ -54,6 +68,8 @@ class ManifestMethod: @dataclass class ManifestEncryptionInformation: + """Encryption information in TDF manifest.""" + type: str policy: str keyAccess: list[ManifestKeyAccess] @@ -63,6 +79,8 @@ class ManifestEncryptionInformation: @dataclass class ManifestPayload: + """Payload information in TDF manifest.""" + type: str url: str protocol: str @@ -72,12 +90,16 @@ class ManifestPayload: @dataclass class ManifestBinding: + """Assertion binding information.""" + method: str signature: str @dataclass class ManifestAssertion: + """TDF assertion in manifest.""" + id: str type: str scope: str @@ -88,6 +110,8 @@ class ManifestAssertion: @dataclass class Manifest: + """TDF manifest with encryption and payload information.""" + schemaVersion: str | None = None encryptionInformation: ManifestEncryptionInformation | None = None payload: ManifestPayload | None = None diff --git a/src/otdf_python/nanotdf.py b/src/otdf_python/nanotdf.py index 40d0618..75671f1 100644 --- a/src/otdf_python/nanotdf.py +++ b/src/otdf_python/nanotdf.py @@ -1,3 +1,5 @@ +"""NanoTDF reader and writer implementation.""" + import contextlib import hashlib import json @@ -24,22 +26,32 @@ class NanoTDFException(SDKException): + """Base exception for NanoTDF operations.""" + pass class NanoTDFMaxSizeLimit(NanoTDFException): + """Exception for NanoTDF size limit exceeded.""" + pass class UnsupportedNanoTDFFeature(NanoTDFException): + """Exception for unsupported NanoTDF features.""" + pass class InvalidNanoTDFConfig(NanoTDFException): + """Exception for invalid NanoTDF configuration.""" + pass class NanoTDF: + """NanoTDF reader and writer for compact TDF format.""" + MAGIC_NUMBER_AND_VERSION = MAGIC_NUMBER_AND_VERSION K_MAX_TDF_SIZE = (16 * 1024 * 1024) - 3 - 32 K_NANOTDF_GMAC_LENGTH = 8 @@ -48,6 +60,7 @@ class NanoTDF: K_EMPTY_IV = bytes([0x0] * 12) def __init__(self, services=None, collection_store: CollectionStore | None = None): + """Initialize NanoTDF reader/writer.""" self.services = services self.collection_store = collection_store or NoOpCollectionStore() @@ -59,7 +72,7 @@ def _create_policy_object(self, attributes: 
list[str]) -> PolicyObject: return PolicyObject(uuid=policy_uuid, body=body) def _serialize_policy_object(self, obj): - """Custom NanoTDF serializer to convert to compatible JSON format.""" + """Serialize policy object to compatible JSON format.""" from otdf_python.policy_object import AttributeObject, PolicyBody if isinstance(obj, PolicyBody): @@ -82,8 +95,7 @@ def _serialize_policy_object(self, obj): return obj.__dict__ def _prepare_payload(self, payload: bytes | BytesIO) -> bytes: - """ - Convert BytesIO to bytes and validate payload size. + """Convert BytesIO to bytes and validate payload size. Args: payload: The payload data as bytes or BytesIO @@ -93,6 +105,7 @@ def _prepare_payload(self, payload: bytes | BytesIO) -> bytes: Raises: NanoTDFMaxSizeLimit: If the payload exceeds the maximum size + """ if isinstance(payload, BytesIO): payload = payload.getvalue() @@ -101,14 +114,14 @@ def _prepare_payload(self, payload: bytes | BytesIO) -> bytes: return payload def _prepare_policy_data(self, config: NanoTDFConfig) -> tuple[bytes, str]: - """ - Prepare policy data from configuration. + """Prepare policy data from configuration. Args: config: NanoTDFConfig configuration Returns: tuple: (policy_body, policy_type) + """ attributes = config.attributes if config.attributes else [] policy_object = self._create_policy_object(attributes) @@ -150,8 +163,7 @@ def _create_header( config: NanoTDFConfig, ephemeral_public_key: bytes | None = None, ) -> bytes: - """ - Create the NanoTDF header. + """Create the NanoTDF header. Args: policy_body: The policy body bytes @@ -161,6 +173,7 @@ def _create_header( Returns: bytes: The header bytes + """ from otdf_python.header import Header # Local import to avoid circular import @@ -228,8 +241,7 @@ def _create_header( return self.MAGIC_NUMBER_AND_VERSION + header_bytes def _is_ec_key(self, key_pem: str) -> bool: - """ - Detect if a PEM key is an EC key (vs RSA). + """Detect if a PEM key is an EC key (vs RSA). Args: key_pem: PEM-formatted key string @@ -239,6 +251,7 @@ def _is_ec_key(self, key_pem: str) -> bool: Raises: SDKException: If key cannot be parsed + """ try: # Try to load as public key first @@ -265,8 +278,7 @@ def _is_ec_key(self, key_pem: str) -> bool: def _derive_key_with_ecdh( # noqa: C901 self, config: NanoTDFConfig ) -> tuple[bytes, bytes | None, bytes | None]: - """ - Derive encryption key using ECDH if KAS public key is provided or can be fetched. + """Derive encryption key using ECDH if KAS public key is provided or can be fetched. This implements the NanoTDF spec's ECDH + HKDF key derivation: 1. Generate ephemeral keypair @@ -283,6 +295,7 @@ def _derive_key_with_ecdh( # noqa: C901 - derived_key: 32-byte AES-256 key for encrypting the payload - ephemeral_public_key_compressed: Compressed ephemeral public key to store in header (None for RSA) - kas_public_key: KAS public key PEM string (or None if not available) + """ import logging @@ -384,8 +397,7 @@ def _derive_key_with_ecdh( # noqa: C901 return derived_key, ephemeral_public_key_compressed, kas_public_key def _encrypt_payload(self, payload: bytes, key: bytes) -> tuple[bytes, bytes]: - """ - Encrypt the payload using AES-GCM. + """Encrypt the payload using AES-GCM. 
Args: payload: The payload to encrypt @@ -393,6 +405,7 @@ def _encrypt_payload(self, payload: bytes, key: bytes) -> tuple[bytes, bytes]: Returns: tuple: (iv, ciphertext) + """ iv = secrets.token_bytes(self.K_NANOTDF_IV_SIZE) iv_padded = self.K_EMPTY_IV[: self.K_IV_PADDING] + iv @@ -403,8 +416,7 @@ def _encrypt_payload(self, payload: bytes, key: bytes) -> tuple[bytes, bytes]: def create_nano_tdf( self, payload: bytes | BytesIO, output_stream: BinaryIO, config: NanoTDFConfig ) -> int: - """ - Stream-based NanoTDF creation - writes encrypted payload to an output stream. + """Stream-based NanoTDF creation - writes encrypted payload to an output stream. For convenience method that returns bytes, use create_nanotdf() instead. Supports ECDH key derivation if KAS info with public key is provided in config. @@ -422,8 +434,8 @@ def create_nano_tdf( UnsupportedNanoTDFFeature: If an unsupported feature is requested InvalidNanoTDFConfig: If the configuration is invalid SDKException: For other errors - """ + """ # Process payload and validate size payload = self._prepare_payload(payload) @@ -557,8 +569,7 @@ def read_nano_tdf( # noqa: C901 output_stream: BinaryIO, config: NanoTDFConfig, ) -> None: - """ - Stream-based NanoTDF decryption - writes decrypted payload to an output stream. + """Stream-based NanoTDF decryption - writes decrypted payload to an output stream. For convenience method that returns bytes, use read_nanotdf() instead. Supports ECDH key derivation and KAS key unwrapping. @@ -571,6 +582,7 @@ def read_nano_tdf( # noqa: C901 Raises: InvalidNanoTDFConfig: If the NanoTDF format is invalid or config is missing required info SDKException: For other errors + """ # Convert to bytes if BytesIO if isinstance(nano_tdf_data, BytesIO): @@ -768,8 +780,7 @@ def _handle_legacy_key_config( return key, config def create_nanotdf(self, data: bytes, config: dict | NanoTDFConfig) -> bytes: - """ - Convenience method - creates a NanoTDF and returns the encrypted bytes. + """Create a NanoTDF and return the encrypted bytes. For stream-based version, use create_nano_tdf() instead. """ @@ -846,8 +857,7 @@ def _extract_key_for_reading( def read_nanotdf( self, nanotdf_bytes: bytes, config: dict | NanoTDFConfig | None = None ) -> bytes: - """ - Convenience method - decrypts a NanoTDF and returns the plaintext bytes. + """Decrypt a NanoTDF and return the plaintext bytes. For stream-based version, use read_nano_tdf() instead. """ diff --git a/src/otdf_python/nanotdf_ecdsa_struct.py b/src/otdf_python/nanotdf_ecdsa_struct.py index da14939..1e9214c 100644 --- a/src/otdf_python/nanotdf_ecdsa_struct.py +++ b/src/otdf_python/nanotdf_ecdsa_struct.py @@ -1,6 +1,4 @@ -""" -NanoTDF ECDSA Signature Structure. -""" +"""NanoTDF ECDSA Signature Structure.""" from dataclasses import dataclass, field @@ -13,8 +11,7 @@ class IncorrectNanoTDFECDSASignatureSize(Exception): @dataclass class NanoTDFECDSAStruct: - """ - Class to handle ECDSA signature structure for NanoTDF. + """Class to handle ECDSA signature structure for NanoTDF. This structure represents an ECDSA signature as required by the NanoTDF format. It consists of r and s values along with their lengths. @@ -29,8 +26,7 @@ class NanoTDFECDSAStruct: def from_bytes( cls, ecdsa_signature_value: bytes, key_size: int ) -> "NanoTDFECDSAStruct": - """ - Create a NanoTDFECDSAStruct from a byte array. + """Create a NanoTDFECDSAStruct from a byte array. 
Args: ecdsa_signature_value: The signature value as bytes @@ -41,6 +37,7 @@ def from_bytes( Raises: IncorrectNanoTDFECDSASignatureSize: If the signature buffer size is invalid + """ if len(ecdsa_signature_value) != (2 * key_size) + 2: raise IncorrectNanoTDFECDSASignatureSize( @@ -72,8 +69,7 @@ def from_bytes( return struct_obj def as_bytes(self) -> bytes: - """ - Convert the signature structure to bytes. + """Convert the signature structure to bytes. Raises ValueError if r_value or s_value is None. """ if self.r_value is None or self.s_value is None: diff --git a/src/otdf_python/nanotdf_type.py b/src/otdf_python/nanotdf_type.py index 4ce112e..97d67e9 100644 --- a/src/otdf_python/nanotdf_type.py +++ b/src/otdf_python/nanotdf_type.py @@ -1,7 +1,11 @@ +"""NanoTDF type enumeration.""" + from enum import Enum class ECCurve(Enum): + """Elliptic curve enumeration for NanoTDF.""" + SECP256R1 = "secp256r1" SECP384R1 = "secp384r1" SECP521R1 = "secp384r1" @@ -12,11 +16,15 @@ def __str__(self): class Protocol(Enum): + """Protocol enumeration for KAS communication.""" + HTTP = "HTTP" HTTPS = "HTTPS" class IdentifierType(Enum): + """Identifier type enumeration for NanoTDF.""" + NONE = 0 TWO_BYTES = 2 EIGHT_BYTES = 8 @@ -27,6 +35,8 @@ def get_length(self): class PolicyType(Enum): + """Policy type enumeration for NanoTDF.""" + REMOTE_POLICY = 0 EMBEDDED_POLICY_PLAIN_TEXT = 1 EMBEDDED_POLICY_ENCRYPTED = 2 @@ -34,6 +44,8 @@ class PolicyType(Enum): class Cipher(Enum): + """Cipher enumeration for NanoTDF encryption.""" + AES_256_GCM_64_TAG = 0 AES_256_GCM_96_TAG = 1 AES_256_GCM_104_TAG = 2 diff --git a/src/otdf_python/policy_binding_serializer.py b/src/otdf_python/policy_binding_serializer.py index 72e3849..2b1c05e 100644 --- a/src/otdf_python/policy_binding_serializer.py +++ b/src/otdf_python/policy_binding_serializer.py @@ -1,21 +1,23 @@ +"""Policy binding serialization for HMAC calculation.""" + from typing import Any class PolicyBinding: - """ - Represents a policy binding in the TDF manifest. + """Represents a policy binding in the TDF manifest. + This is a placeholder implementation as the complete details of the PolicyBinding class aren't provided in the code snippets. """ def __init__(self, **kwargs): + """Initialize policy binding from kwargs.""" for key, value in kwargs.items(): setattr(self, key, value) class PolicyBindingSerializer: - """ - Handles serialization and deserialization of policy bindings. + """Handles serialization and deserialization of policy bindings. This class provides static methods to convert between JSON representations and PolicyBinding objects. 
""" diff --git a/src/otdf_python/policy_info.py b/src/otdf_python/policy_info.py index 5278c04..5c6105e 100644 --- a/src/otdf_python/policy_info.py +++ b/src/otdf_python/policy_info.py @@ -1,9 +1,15 @@ +"""Policy information handling for NanoTDF.""" + + class PolicyInfo: + """Policy information.""" + def __init__( self, policy_type: int = 0, body: bytes | None = None, ): + """Initialize policy information.""" self.policy_type = policy_type self.body = body diff --git a/src/otdf_python/policy_object.py b/src/otdf_python/policy_object.py index 83baa78..fe06129 100644 --- a/src/otdf_python/policy_object.py +++ b/src/otdf_python/policy_object.py @@ -1,8 +1,12 @@ +"""Policy object dataclasses for OpenTDF.""" + from dataclasses import dataclass @dataclass class AttributeObject: + """An attribute object.""" + attribute: str display_name: str | None = None is_default: bool = False @@ -12,11 +16,15 @@ class AttributeObject: @dataclass class PolicyBody: + """A policy body.""" + data_attributes: list[AttributeObject] dissem: list[str] @dataclass class PolicyObject: + """A policy object.""" + uuid: str body: PolicyBody diff --git a/src/otdf_python/policy_stub.py b/src/otdf_python/policy_stub.py index 8001149..9236ef0 100644 --- a/src/otdf_python/policy_stub.py +++ b/src/otdf_python/policy_stub.py @@ -1,2 +1,4 @@ +"""Policy UUID constants for OpenTDF.""" + # TODO: Replace this with a proper Policy UUID values NULL_POLICY_UUID: str = "00000000-0000-0000-0000-000000000000" diff --git a/src/otdf_python/resource_locator.py b/src/otdf_python/resource_locator.py index fd2f739..22e8b3e 100644 --- a/src/otdf_python/resource_locator.py +++ b/src/otdf_python/resource_locator.py @@ -1,15 +1,19 @@ +"""NanoTDF resource locator handling.""" + + class ResourceLocator: - """ - NanoTDF Resource Locator per the spec: - https://github.com/opentdf/spec/blob/main/schema/nanotdf/README.md + """Represent NanoTDF Resource Locator per specification. + + See https://github.com/opentdf/spec/blob/main/schema/nanotdf/README.md Format: - - Byte 0: Protocol Enum (bits 0-3) + Identifier Length (bits 4-7) - - Protocol: 0x0=HTTP, 0x1=HTTPS, 0xF=Shared Resource Directory - - Identifier: 0x0=None, 0x1=2 bytes, 0x2=8 bytes, 0x3=32 bytes - - Byte 1: Body Length (1-255 bytes) - - Bytes 2-N: Body (URL path) - - Bytes N+1-M: Identifier (optional, 0/2/8/32 bytes) + - Byte 0: Protocol Enum (bits 0-3) + Identifier Length (bits 4-7) + - Protocol: 0x0=HTTP, 0x1=HTTPS, 0xF=Shared Resource Directory + - Identifier: 0x0=None, 0x1=2 bytes, 0x2=8 bytes, 0x3=32 bytes + - Byte 1: Body Length (1-255 bytes) + - Bytes 2-N: Body (URL path) + - Bytes N+1-M: Identifier (optional, 0/2/8/32 bytes) + """ # Protocol enum values @@ -24,6 +28,13 @@ class ResourceLocator: IDENTIFIER_32_BYTES = 0x3 def __init__(self, resource_url: str | None = None, identifier: str | None = None): + """Initialize resource locator. + + Args: + resource_url: URL of the resource + identifier: Optional identifier for the resource + + """ self.resource_url = resource_url or "" self.identifier = identifier or "" @@ -71,8 +82,7 @@ def _get_identifier_bytes(self): raise ValueError(f"Identifier too long: {id_len} bytes (max 32)") def to_bytes(self): - """ - Convert to NanoTDF Resource Locator format per spec. + """Convert to NanoTDF Resource Locator format per spec. 
Format: - Byte 0: Protocol Enum (bits 0-3) + Identifier Length (bits 4-7) @@ -106,8 +116,7 @@ def write_into_buffer(self, buffer: bytearray, offset: int = 0) -> int: @staticmethod def from_bytes_with_size(buffer: bytes): # noqa: C901 - """ - Parse NanoTDF Resource Locator from bytes per spec. + """Parse NanoTDF Resource Locator from bytes per spec. Format: - Byte 0: Protocol Enum (bits 0-3) + Identifier Length (bits 4-7) diff --git a/src/otdf_python/sdk.py b/src/otdf_python/sdk.py index 87bdba0..3666f69 100644 --- a/src/otdf_python/sdk.py +++ b/src/otdf_python/sdk.py @@ -1,6 +1,4 @@ -""" -Python port of the main SDK class for OpenTDF platform interaction. -""" +"""The main SDK class for OpenTDF platform interaction.""" from contextlib import AbstractContextManager from io import BytesIO @@ -13,13 +11,10 @@ class KAS(AbstractContextManager): - """ - KAS (Key Access Service) interface to define methods related to key access and management. - """ + """KAS (Key Access Service) interface to define methods related to key access and management.""" def get_public_key(self, kas_info: Any) -> Any: - """ - Retrieves the public key from the KAS for RSA operations. + """Retrieve the public key from KAS for RSA operations. If the public key is cached, returns the cached value. Otherwise, makes a request to the KAS. @@ -31,6 +26,7 @@ def get_public_key(self, kas_info: Any) -> Any: Raises: SDKException: If there's an error retrieving the public key + """ # Delegate to the underlying KAS client which handles authentication properly return self._kas_client.get_public_key(kas_info) @@ -42,14 +38,14 @@ def __init__( sdk_ssl_verify=True, use_plaintext=False, ): - """ - Initialize the KAS client + """Initialize the KAS client. Args: platform_url: URL of the platform token_source: Function that returns an authentication token sdk_ssl_verify: Whether to verify SSL certificates use_plaintext: Whether to use plaintext HTTP connections instead of HTTPS + """ from .kas_client import KASClient @@ -64,8 +60,7 @@ def __init__( self._use_plaintext = use_plaintext def get_ec_public_key(self, kas_info: Any, curve: Any) -> Any: - """ - Retrieves the EC public key from the KAS. + """Retrieve the EC public key from KAS. Args: kas_info: KASInfo object containing the URL @@ -73,6 +68,7 @@ def get_ec_public_key(self, kas_info: Any, curve: Any) -> Any: Returns: Updated KASInfo object with KID and PublicKey populated + """ # Set algorithm to "ec:" from copy import copy @@ -82,8 +78,7 @@ def get_ec_public_key(self, kas_info: Any, curve: Any) -> Any: return self.get_public_key(kas_info_copy) def unwrap(self, key_access: Any, policy: str, session_key_type: Any) -> bytes: - """ - Unwraps the key using the KAS. + """Unwraps the key using the KAS. Args: key_access: KeyAccess object containing the wrapped key @@ -92,6 +87,7 @@ def unwrap(self, key_access: Any, policy: str, session_key_type: Any) -> bytes: Returns: Unwrapped key as bytes + """ return self._kas_client.unwrap(key_access, policy, session_key_type) @@ -104,8 +100,7 @@ def unwrap_nanotdf( kas_private_key: str | None = None, mock: bool = False, ) -> bytes: - """ - Unwraps the NanoTDF key using the KAS. If mock=True, performs local unwrap using the private key (for tests). + """Unwraps the NanoTDF key using the KAS. If mock=True, performs local unwrap using the private key (for tests). 
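The NanoTDF Resource Locator layout described in the resource_locator.py hunks above (protocol enum in the low nibble, identifier-length code in the high nibble, a one-byte body length, then the body and optional identifier) can be packed in a few lines. This is a hedged standalone sketch under those stated rules, not the SDK's to_bytes implementation; 0x1 (HTTPS) and the empty identifier are illustrative defaults.

    def pack_resource_locator(body: str, protocol: int = 0x1, identifier: bytes = b"") -> bytes:
        """Illustrative packing of the locator layout described above."""
        id_nibble = {0: 0x0, 2: 0x1, 8: 0x2, 32: 0x3}[len(identifier)]  # identifier length code
        body_bytes = body.encode("utf-8")
        if not 1 <= len(body_bytes) <= 255:
            raise ValueError("body must be 1-255 bytes")
        first_byte = (id_nibble << 4) | (protocol & 0x0F)  # protocol in bits 0-3, id length in bits 4-7
        return bytes([first_byte, len(body_bytes)]) + body_bytes + identifier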
Args: curve: EC curve used @@ -117,6 +112,7 @@ def unwrap_nanotdf( Returns: Unwrapped key as bytes + """ if mock and wrapped_key and kas_private_key: from .asym_crypto import AsymDecryption @@ -128,16 +124,16 @@ def unwrap_nanotdf( raise NotImplementedError("KAS unwrap_nanotdf not implemented.") def get_key_cache(self) -> Any: - """ - Returns the KAS key cache. + """Return the KAS key cache. Returns: The KAS key cache object + """ return self._kas_client.get_key_cache() def close(self): - """Closes resources associated with the KAS interface""" + """Close resources associated with KAS interface.""" pass def __exit__(self, exc_type, exc_val, exc_tb): @@ -145,16 +141,15 @@ def __exit__(self, exc_type, exc_val, exc_tb): class SDK(AbstractContextManager): + """SDK for OpenTDF platform interaction.""" + def new_tdf_config( self, attributes: list[str] | None = None, kas_info_list: list[KASInfo] | None = None, **kwargs, ) -> TDFConfig: - """ - Create a TDFConfig with default kas_info_list from the SDK's platform_url. - """ - + """Create a TDFConfig with default kas_info_list from the SDK's platform_url.""" if self.platform_url is None: raise SDKException("Cannot create TDFConfig: SDK platform_url is not set.") @@ -214,19 +209,16 @@ def new_tdf_config( """ class Services(AbstractContextManager): - """ - The Services interface provides access to various platform service clients and KAS. - """ + """The Services interface provides access to various platform service clients and KAS.""" def kas(self) -> KAS: - """ - Returns the KAS client for key access operations. + """Return the KAS client for key access operations. This should be implemented to return an instance of KAS. """ raise NotImplementedError def close(self): - """Closes resources associated with the services""" + """Close resources associated with the services.""" pass def __exit__(self, exc_type, exc_val, exc_tb): @@ -239,14 +231,14 @@ def __init__( ssl_verify: bool = True, use_plaintext: bool = False, ): - """ - Initializes a new SDK instance. + """Initialize a new SDK instance. Args: services: The services interface implementation platform_url: Optional platform base URL ssl_verify: Whether to verify SSL certificates (default: True) use_plaintext: Whether to use HTTP instead of HTTPS (default: False) + """ self.services = services self.platform_url = platform_url @@ -254,20 +246,20 @@ def __init__( self._use_plaintext = use_plaintext def __exit__(self, exc_type, exc_val, exc_tb): - """Clean up resources when exiting context manager""" + """Clean up resources when exiting context manager.""" self.close() def close(self): - """Close the SDK and release resources""" + """Close the SDK and release resources.""" if hasattr(self.services, "close"): self.services.close() def get_services(self) -> "SDK.Services": - """Returns the services interface""" + """Return the services interface.""" return self.services def get_platform_url(self) -> str | None: - """Returns the platform URL if set""" + """Return the platform URL if set.""" return self.platform_url def load_tdf( @@ -275,8 +267,7 @@ def load_tdf( tdf_data: bytes | BinaryIO | BytesIO, config: TDFReaderConfig | None = None, ) -> TDFReader: - """ - Loads a TDF from the provided data, optionally according to the config. + """Load a TDF from the provided data, optionally according to config. 
Args: tdf_data: The TDF data as bytes, file object, or BytesIO @@ -287,6 +278,7 @@ def load_tdf( Raises: SDKException: If there's an error loading the TDF + """ tdf = TDF(self.services) if config is None: @@ -300,8 +292,7 @@ def create_tdf( config, output_stream: BinaryIO | None = None, ): - """ - Creates a TDF with the provided payload. + """Create a TDF with the provided payload. Args: payload: The payload data as bytes, file object, or BytesIO @@ -313,6 +304,7 @@ def create_tdf( Raises: SDKException: If there's an error creating the TDF + """ tdf = TDF(self.services) return tdf.create_tdf(payload, config, output_stream) @@ -320,8 +312,7 @@ def create_tdf( def create_nano_tdf( self, payload: bytes | BytesIO, output_stream: BinaryIO, config: "NanoTDFConfig" ) -> int: - """ - Creates a NanoTDF with the provided payload. + """Create a NanoTDF with the provided payload. Args: payload: The payload data as bytes or BytesIO @@ -333,6 +324,7 @@ def create_nano_tdf( Raises: SDKException: If there's an error creating the NanoTDF + """ nano_tdf = NanoTDF(self.services) return nano_tdf.create_nano_tdf(payload, output_stream, config) @@ -343,8 +335,7 @@ def read_nano_tdf( output_stream: BinaryIO, config: NanoTDFConfig, ) -> None: - """ - Reads a NanoTDF and writes the payload to the output stream. + """Read a NanoTDF and write the payload to the output stream. Args: nano_tdf_data: The NanoTDF data as bytes or BytesIO @@ -353,20 +344,21 @@ def read_nano_tdf( Raises: SDKException: If there's an error reading the NanoTDF + """ nano_tdf = NanoTDF(self.services) nano_tdf.read_nano_tdf(nano_tdf_data, output_stream, config) @staticmethod def is_tdf(data: bytes | BinaryIO) -> bool: - """ - Checks if the provided data is a TDF. + """Check if the provided data is a TDF. 
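Taken together, the facade methods above support a simple round trip. A hedged usage sketch follows; it assumes sdk is an already-built SDK instance pointed at a reachable platform/KAS, and the attribute FQN is a placeholder.

    from io import BytesIO

    # `sdk` is assumed to come from SDKBuilder (see sdk_builder.py below); the attribute is a placeholder.
    config = sdk.new_tdf_config(attributes=["https://example.com/attr/classification/value/secret"])
    out = BytesIO()
    manifest, size, _ = sdk.create_tdf(b"hello world", config, output_stream=out)

    reader = sdk.load_tdf(out.getvalue())
    assert reader.payload == b"hello world"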
Args: data: The data to check Returns: bool: True if the data is a TDF, False otherwise + """ import zipfile from io import BytesIO @@ -383,54 +375,54 @@ def is_tdf(data: bytes | BinaryIO) -> bool: # Exception classes - SDK-specific exceptions that can occur during operations class SplitKeyException(SDKException): - """Thrown when the SDK encounters an error related to split key operations""" + """Raised when the SDK encounters an error related to split key operations.""" pass class DataSizeNotSupported(SDKException): - """Thrown when the user attempts to create a TDF with a size larger than the maximum size""" + """Raised when the user attempts to create a TDF larger than the maximum size.""" pass class KasInfoMissing(SDKException): - """Thrown during TDF creation when no KAS information is present""" + """Raised during TDF creation when no KAS information is present.""" pass class KasPublicKeyMissing(SDKException): - """Thrown during encryption when the SDK cannot retrieve the public key for a KAS""" + """Raised during encryption when the SDK cannot retrieve the public key for a KAS.""" pass class TamperException(SDKException): - """Base class for exceptions related to signature mismatches""" + """Base class for exceptions related to signature mismatches.""" def __init__(self, error_message: str): + """Initialize tamper exception.""" super().__init__(f"[tamper detected] {error_message}") class RootSignatureValidationException(TamperException): - """Thrown when the root signature validation fails""" - - pass + """Raised when the root signature validation fails.""" class SegmentSignatureMismatch(TamperException): - """Thrown when a segment signature does not match the expected value""" + """Raised when a segment signature does not match the expected value.""" pass class KasBadRequestException(SDKException): - """Thrown when the KAS returns a bad request response""" + """Raised when the KAS returns a bad request response.""" pass class KasAllowlistException(SDKException): - """Thrown when the KAS allowlist check fails""" + """Raised when the KAS allowlist check fails.""" pass class AssertionException(SDKException): - """Thrown when an assertion validation fails""" + """Raised when an assertion validation fails.""" def __init__(self, error_message: str, assertion_id: str): + """Initialize exception.""" super().__init__(error_message) self.assertion_id = assertion_id diff --git a/src/otdf_python/sdk_builder.py b/src/otdf_python/sdk_builder.py index a1059db..f1c3499 100644 --- a/src/otdf_python/sdk_builder.py +++ b/src/otdf_python/sdk_builder.py @@ -1,5 +1,5 @@ -""" -Python port of the SDKBuilder class for OpenTDF platform interaction. +"""SDKBuilder class for OpenTDF platform interaction. + Provides methods to configure and build SDK instances. """ @@ -19,6 +19,8 @@ @dataclass class OAuthConfig: + """OAuth configuration.""" + client_id: str client_secret: str grant_type: str = "client_credentials" @@ -28,9 +30,7 @@ class SDKBuilder: - """ - A builder class for creating instances of the SDK class. - """ + """A builder class for creating instances of the SDK class.""" PLATFORM_ISSUER = "platform_issuer" @@ -38,6 +38,7 @@ _platform_url = None def __init__(self): + """Initialize SDK builder.""" self.platform_endpoint: str | None = None self.issuer_endpoint: str | None = None self.oauth_config: OAuthConfig | None = None @@ -49,29 +50,33 @@ @staticmethod def new_builder() -> "SDKBuilder": - """ - Creates a new SDKBuilder instance. + """Create a new SDKBuilder instance.
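Because RootSignatureValidationException and SegmentSignatureMismatch above both derive from TamperException, one handler can cover integrity failures. A hypothetical sketch; whether load_tdf surfaces these directly for a tampered archive is an assumption, as are the sdk and tdf_bytes names.

    from otdf_python.sdk import TamperException

    # Catches RootSignatureValidationException and SegmentSignatureMismatch alike (assumed call site).
    try:
        reader = sdk.load_tdf(tdf_bytes)
    except TamperException as exc:
        print(f"integrity check failed: {exc}")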
+ Returns: SDKBuilder: A new builder instance + """ return SDKBuilder() @staticmethod def get_platform_url() -> str | None: - """ - Gets the last set platform URL. + """Get the last set platform URL. + Returns: str | None: The platform URL or None if not set + """ return SDKBuilder._platform_url def ssl_context_from_directory(self, certs_dir_path: str) -> "SDKBuilder": - """ - Add SSL Context with trusted certs from certDirPath + """Add SSL context with trusted certs from certDirPath. + Args: certs_dir_path: Path to a directory containing .pem or .crt trusted certs + Returns: self: The builder instance for chaining + """ self.cert_paths = [] @@ -91,13 +96,14 @@ def ssl_context_from_directory(self, certs_dir_path: str) -> "SDKBuilder": return self def client_secret(self, client_id: str, client_secret: str) -> "SDKBuilder": - """ - Sets client credentials for OAuth 2.0 client_credentials grant. + """Set client credentials for OAuth 2.0 client_credentials grant. + Args: client_id: The OAuth client ID client_secret: The OAuth client secret Returns: self: The builder instance for chaining + """ self.oauth_config = OAuthConfig( client_id=client_id, client_secret=client_secret @@ -105,12 +111,13 @@ def client_secret(self, client_id: str, client_secret: str) -> "SDKBuilder": return self def set_platform_endpoint(self, endpoint: str) -> "SDKBuilder": - """ - Sets the OpenTDF platform endpoint URL. + """Set the OpenTDF platform endpoint URL. + Args: endpoint: The platform endpoint URL Returns: self: The builder instance for chaining + """ # Normalize the endpoint URL if endpoint and not ( @@ -127,12 +134,13 @@ def set_platform_endpoint(self, endpoint: str) -> "SDKBuilder": return self def set_issuer_endpoint(self, issuer: str) -> "SDKBuilder": - """ - Sets the OpenID Connect issuer endpoint URL. + """Set the OpenID Connect issuer endpoint URL. + Args: issuer: The issuer endpoint URL Returns: self: The builder instance for chaining + """ # Normalize the issuer URL if issuer and not ( @@ -146,12 +154,13 @@ def set_issuer_endpoint(self, issuer: str) -> "SDKBuilder": def use_insecure_plaintext_connection( self, use_plaintext: bool = True ) -> "SDKBuilder": - """ - Configures whether to use plain text (HTTP) connection instead of HTTPS. + """Configure whether to use plain text (HTTP) instead of HTTPS. + Args: use_plaintext: Whether to use plain text connection Returns: self: The builder instance for chaining + """ self.use_plaintext = use_plaintext @@ -168,12 +177,13 @@ def use_insecure_plaintext_connection( return self def use_insecure_skip_verify(self, skip_verify: bool = True) -> "SDKBuilder": - """ - Configures whether to skip SSL verification. + """Configure whether to skip SSL verification. + Args: skip_verify: Whether to skip SSL verification Returns: self: The builder instance for chaining + """ self.insecure_skip_verify = skip_verify @@ -184,21 +194,23 @@ def use_insecure_skip_verify(self, skip_verify: bool = True) -> "SDKBuilder": return self def bearer_token(self, token: str) -> "SDKBuilder": - """ - Sets a bearer token to use for authorization. + """Set a bearer token to use for authorization. + Args: token: The bearer token Returns: self: The builder instance for chaining + """ self.auth_token = token return self def _discover_token_endpoint_from_platform(self) -> None: - """ - Discover token endpoint using OpenTDF platform configuration. + """Discover token endpoint using OpenTDF platform configuration. 
+ Raises: AutoConfigureException: If discovery fails + """ if not self.platform_endpoint or not self.oauth_config: return @@ -232,12 +244,13 @@ def _discover_token_endpoint_from_platform(self) -> None: self._discover_token_endpoint_from_issuer(platform_issuer) def _discover_token_endpoint_from_issuer(self, issuer_url: str) -> None: - """ - Discover token endpoint using OIDC discovery from issuer. + """Discover token endpoint using OIDC discovery from issuer. + Args: issuer_url: The issuer URL to use for discovery Raises: AutoConfigureException: If discovery fails + """ if not self.oauth_config: return @@ -260,10 +273,11 @@ def _discover_token_endpoint_from_issuer(self, issuer_url: str) -> None: ) def _discover_token_endpoint(self) -> None: - """ - Discover the token endpoint using available configuration. + """Discover the token endpoint using available configuration. + Raises: AutoConfigureException: If discovery fails + """ # Try platform endpoint first if self.platform_endpoint: @@ -297,12 +311,13 @@ def _discover_token_endpoint(self) -> None: ) def _get_token_from_client_credentials(self) -> str: - """ - Obtains an OAuth token using client credentials. + """Obtain an OAuth token using client credentials. + Returns: - str: The access token + str: The OAuth access token Raises: AutoConfigureException: If token acquisition fails + """ if not self.oauth_config: raise AutoConfigureException("OAuth configuration is not set") @@ -346,12 +361,13 @@ def _get_token_from_client_credentials(self) -> str: ) from e def _create_services(self) -> SDK.Services: - """ - Creates service client instances. + """Create service client instances. + Returns: SDK.Services: The service client instances Raises: AutoConfigureException: If service creation fails + """ # For now, return a simple implementation of Services # In a real implementation, this would create actual service clients @@ -366,9 +382,7 @@ def __init__(self, builder_instance): self._builder = builder_instance def kas(self) -> KAS: - """ - Returns the KAS interface with SSL verification settings. - """ + """Return the KAS interface with SSL verification settings.""" platform_url = SDKBuilder.get_platform_url() # Create a token source function that can refresh tokens @@ -396,12 +410,13 @@ def __exit__(self, exc_type, exc_val, exc_tb): return ServicesImpl(self) def build(self) -> SDK: - """ - Builds and returns an SDK instance with the configured properties. + """Build and return an SDK instance with configured properties. 
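A usage sketch of the builder chain defined in this file; every method name is taken from the hunks above, while the endpoint and credentials are placeholders.

    from otdf_python.sdk_builder import SDKBuilder

    # Placeholder endpoint and credentials; the insecure options are left at their defaults.
    sdk = (
        SDKBuilder.new_builder()
        .set_platform_endpoint("https://platform.example.com")
        .client_secret("example-client-id", "example-client-secret")
        .build()
    )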
+ Returns: SDK: The configured SDK instance Raises: AutoConfigureException: If the build fails + """ if not self.platform_endpoint: raise AutoConfigureException("Platform endpoint is not set") diff --git a/src/otdf_python/sdk_exceptions.py b/src/otdf_python/sdk_exceptions.py index cc8fbee..f8bce6e 100644 --- a/src/otdf_python/sdk_exceptions.py +++ b/src/otdf_python/sdk_exceptions.py @@ -1,16 +1,26 @@ +"""SDK-specific exception classes.""" + + class SDKException(Exception): + """Base SDK exception class.""" + def __init__(self, message, reason=None): + """Initialize exception.""" super().__init__(message) self.reason = reason class AutoConfigureException(SDKException): + """Exception for SDK auto-configuration failures.""" + def __init__(self, message, cause=None): + """Initialize exception.""" super().__init__(message, cause) class KASBadRequestException(SDKException): - """Thrown when the KAS returns a bad request response or other client request errors.""" + """Exception for KAS bad request or client errors.""" def __init__(self, message): + """Initialize exception.""" super().__init__(message) diff --git a/src/otdf_python/symmetric_and_payload_config.py b/src/otdf_python/symmetric_and_payload_config.py index f7d6c07..e69b195 100644 --- a/src/otdf_python/symmetric_and_payload_config.py +++ b/src/otdf_python/symmetric_and_payload_config.py @@ -1,10 +1,16 @@ +"""Symmetric encryption and payload configuration.""" + + class SymmetricAndPayloadConfig: + """Symmetric and payload configuration.""" + def __init__( self, cipher_type: int = 0, signature_ecc_mode: int = 0, has_signature: bool = True, ): + """Initialize symmetric and payload configuration.""" self.cipher_type = cipher_type self.signature_ecc_mode = signature_ecc_mode self.has_signature = has_signature diff --git a/src/otdf_python/tdf.py b/src/otdf_python/tdf.py index 3951ecb..dafa246 100644 --- a/src/otdf_python/tdf.py +++ b/src/otdf_python/tdf.py @@ -1,3 +1,5 @@ +"""TDF reader and writer functionality for OpenTDF platform.""" + import base64 import hashlib import hmac @@ -31,17 +33,23 @@ @dataclass class TDFReader: + """Container for TDF payload and manifest after reading.""" + payload: bytes manifest: Manifest @dataclass class TDFReaderConfig: + """Configuration for TDF reader operations.""" + kas_private_key: str | None = None attributes: list[str] | None = None class TDF: + """TDF reader and writer for handling TDF encryption and decryption.""" + MAX_TDF_INPUT_SIZE = 68719476736 GCM_KEY_SIZE = 32 GCM_IV_SIZE = 12 @@ -53,6 +61,13 @@ class TDF: GLOBAL_KEY_SALT = b"TDF-Session-Key" def __init__(self, services=None, maximum_size: int | None = None): + """Initialize TDF reader/writer. 
+ + Args: + services: SDK services for KAS operations + maximum_size: Maximum size allowed for TDF operations + + """ self.services = services self.maximum_size = maximum_size or self.MAX_TDF_INPUT_SIZE @@ -162,7 +177,7 @@ def _build_policy_json(self, config: TDFConfig) -> str: return _json.dumps(policy, default=self._serialize_policy_object) def _serialize_policy_object(self, obj): - """Custom TDF serializer to convert to compatible JSON format.""" + """Serialize policy object to compatible JSON format.""" from otdf_python.policy_object import AttributeObject, PolicyBody if isinstance(obj, PolicyBody): @@ -185,9 +200,7 @@ def _serialize_policy_object(self, obj): return obj.__dict__ def _unwrap_key(self, key_access_objs, private_key_pem): - """ - Unwraps the key locally using a provided private key (used for testing) - """ + """Unwrap the key locally using provided private key (used for testing).""" from .asym_crypto import AsymDecryption key = None @@ -204,9 +217,7 @@ def _unwrap_key(self, key_access_objs, private_key_pem): return key def _unwrap_key_with_kas(self, key_access_objs, policy_b64) -> bytes: - """ - Unwraps the key using the KAS service (production method) - """ + """Unwrap the key using the KAS service (production method).""" # Get KAS client from services if not self.services: raise ValueError("SDK services required for KAS operations") @@ -281,6 +292,17 @@ def create_tdf( config: TDFConfig, output_stream: io.BytesIO | None = None, ): + """Create a TDF with the provided payload and configuration. + + Args: + payload: The payload data as bytes or BinaryIO + config: TDFConfig for encryption settings + output_stream: Optional output stream, creates new BytesIO if not provided + + Returns: + Tuple of (manifest, size, output_stream) + + """ if output_stream is None: output_stream = io.BytesIO() writer = TDFWriter(output_stream) @@ -380,6 +402,16 @@ def create_tdf( def load_tdf( self, tdf_data: bytes | io.BytesIO, config: TDFReaderConfig ) -> TDFReader: + """Load and decrypt a TDF from the provided data. + + Args: + tdf_data: The TDF data as bytes or BytesIO + config: TDFReaderConfig with optional private key for local unwrapping + + Returns: + TDFReader containing payload and manifest + + """ # Extract manifest, unwrap payload key using KAS client # Handle both bytes and BinaryIO input tdf_bytes_io = io.BytesIO(tdf_data) if isinstance(tdf_data, bytes) else tdf_data @@ -423,8 +455,13 @@ def load_tdf( def read_payload( self, tdf_bytes: bytes, config: dict, output_stream: BinaryIO ) -> None: - """ - Reads and verifies TDF segments, decrypts if needed, and writes the payload to output_stream. + """Read and verify TDF segments, decrypt if needed, and write the payload. + + Args: + tdf_bytes: The TDF data as bytes + config: Configuration dictionary for reading + output_stream: The output stream to write the payload to + """ import base64 import zipfile diff --git a/src/otdf_python/tdf_reader.py b/src/otdf_python/tdf_reader.py index a414f16..db884fd 100644 --- a/src/otdf_python/tdf_reader.py +++ b/src/otdf_python/tdf_reader.py @@ -1,6 +1,4 @@ -""" -TDFReader is responsible for reading and processing Trusted Data Format (TDF) files. -""" +"""TDFReader is responsible for reading and processing Trusted Data Format (TDF) files.""" from .manifest import Manifest from .policy_object import PolicyObject @@ -13,15 +11,13 @@ class TDFReader: - """ - TDFReader is responsible for reading and processing Trusted Data Format (TDF) files. 
+ """TDFReader is responsible for reading and processing Trusted Data Format (TDF) files. The class initializes with a TDF file channel, extracts the manifest and payload entries, and provides methods to retrieve the manifest content, read payload bytes, and read policy objects. """ def __init__(self, tdf): - """ - Initialize a TDFReader with a TDF file channel. + """Initialize a TDFReader with a TDF file channel. Args: tdf: A file-like object containing the TDF data @@ -29,6 +25,7 @@ def __init__(self, tdf): Raises: SDKException: If there's an error reading the TDF ValueError: If the TDF doesn't contain a manifest or payload + """ try: self._zip_reader = ZipReader(tdf) @@ -48,14 +45,14 @@ def __init__(self, tdf): raise SDKException("Error initializing TDFReader") from e def manifest(self) -> str: - """ - Get the manifest content as a string. + """Get the manifest content as a string. Returns: The manifest content as a UTF-8 encoded string Raises: SDKException: If there's an error retrieving the manifest + """ try: manifest_data = self._zip_reader.read(self._manifest_name) @@ -64,8 +61,7 @@ def manifest(self) -> str: raise SDKException("Error retrieving manifest from zip file") from e def read_payload_bytes(self, buf: bytearray) -> int: - """ - Read bytes from the payload into a buffer. + """Read bytes from the payload into a buffer. Args: buf: A bytearray buffer to read into @@ -75,6 +71,7 @@ def read_payload_bytes(self, buf: bytearray) -> int: Raises: SDKException: If there's an error reading from the payload + """ try: # Read the entire payload @@ -89,14 +86,14 @@ def read_payload_bytes(self, buf: bytearray) -> int: raise SDKException("Error reading from payload in TDF") from e def read_policy_object(self) -> PolicyObject: - """ - Read the policy object from the manifest. + """Read the policy object from the manifest. Returns: The PolicyObject Raises: SDKException: If there's an error reading the policy object + """ try: manifest_text = self.manifest() diff --git a/src/otdf_python/tdf_writer.py b/src/otdf_python/tdf_writer.py index 6dcd7d5..bef7cfd 100644 --- a/src/otdf_python/tdf_writer.py +++ b/src/otdf_python/tdf_writer.py @@ -1,13 +1,18 @@ +"""TDF writer for creating encrypted TDF files.""" + import io from otdf_python.zip_writer import ZipWriter class TDFWriter: + """TDF file writer for creating encrypted TDF packages.""" + TDF_PAYLOAD_FILE_NAME = "0.payload" TDF_MANIFEST_FILE_NAME = "0.manifest.json" def __init__(self, out_stream: io.BytesIO | None = None): + """Initialize TDF writer.""" self._zip_writer = ZipWriter(out_stream) def append_manifest(self, manifest: str): diff --git a/src/otdf_python/token_source.py b/src/otdf_python/token_source.py index 0c60c3a..c5e5ca1 100644 --- a/src/otdf_python/token_source.py +++ b/src/otdf_python/token_source.py @@ -1,6 +1,4 @@ -""" -TokenSource: Handles OAuth2 token acquisition and caching. 
-""" +"""TokenSource: Handles OAuth2 token acquisition and caching.""" import time @@ -8,7 +6,10 @@ class TokenSource: + """OAuth2 token source for authentication.""" + def __init__(self, token_url, client_id, client_secret): + """Initialize token source.""" self.token_url = token_url self.client_id = client_id self.client_secret = client_secret diff --git a/src/otdf_python/version.py b/src/otdf_python/version.py index fd101ec..78b135f 100644 --- a/src/otdf_python/version.py +++ b/src/otdf_python/version.py @@ -1,9 +1,13 @@ +"""SDK version information.""" + import re from functools import total_ordering @total_ordering class Version: + """Semantic version representation.""" + SEMVER_PATTERN = re.compile( r"^(?P<major>0|[1-9]\d*)\.(?P<minor>0|[1-9]\d*)\.(?P<patch>0|[1-9]\d*)(?P<prerelease_and_metadata>\D.*)?$" ) @@ -15,6 +19,7 @@ def __init__( patch=None, prerelease_and_metadata: str | None = None, ): + """Initialize semantic version.""" if minor is None and patch is None: # Parse from string m = self.SEMVER_PATTERN.match(semver_or_major) diff --git a/src/otdf_python/zip_reader.py b/src/otdf_python/zip_reader.py index 73f2b07..02a0d9a 100644 --- a/src/otdf_python/zip_reader.py +++ b/src/otdf_python/zip_reader.py @@ -1,3 +1,5 @@ +"""ZIP file reader for TDF operations.""" + import io import zipfile @@ -5,8 +7,13 @@ class ZipReader: + """ZIP file reader for reading TDF packages.""" + class Entry: + """ZIP file entry with data access.""" + def __init__(self, zipfile_obj, zipinfo): + """Initialize ZIP entry.""" self._zipfile = zipfile_obj self._zipinfo = zipinfo @@ -20,6 +27,7 @@ def get_data(self) -> bytes: raise InvalidZipException(f"Error reading entry data: {e}") from e def __init__(self, in_stream: io.BytesIO | bytes | None = None): + """Initialize ZIP reader.""" try: if isinstance(in_stream, bytes): in_stream = io.BytesIO(in_stream) diff --git a/src/otdf_python/zip_writer.py b/src/otdf_python/zip_writer.py index e548d97..3c9c625 100644 --- a/src/otdf_python/zip_writer.py +++ b/src/otdf_python/zip_writer.py @@ -1,10 +1,15 @@ +"""ZIP file writer for TDF operations.""" + import io import zipfile import zlib class FileInfo: + """ZIP file metadata information.""" + def __init__(self, name: str, crc: int, size: int, offset: int): + """Initialize file info.""" self.name = name self.crc = crc self.size = size @@ -12,7 +17,10 @@ class ZipWriter: + """ZIP file writer for creating TDF packages.""" + def __init__(self, out_stream: io.BytesIO | None = None): + """Initialize ZIP writer.""" self.out_stream = out_stream or io.BytesIO() self.zipfile = zipfile.ZipFile( self.out_stream, mode="w", compression=zipfile.ZIP_STORED ) @@ -45,7 +53,10 @@ def get_file_infos(self) -> list[FileInfo]: class _TrackingWriter(io.RawIOBase): + """Internal ZIP stream writer with offset tracking.""" + def __init__(self, zip_writer: ZipWriter, name: str, offset: int): + """Initialize tracking writer.""" self._zip_writer = zip_writer self._name = name self._offset = offset diff --git a/tests/__init__.py b/tests/__init__.py index 3bcd7e8..f7bfff2 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -1 +1 @@ -# Empty file to make tests a package +"""Test suite for OpenTDF Python SDK.""" diff --git a/tests/config_pydantic.py b/tests/config_pydantic.py index 457f01c..10a5d3b 100644 --- a/tests/config_pydantic.py +++ b/tests/config_pydantic.py @@ -1,5 +1,4 @@ -""" -In this module, we are migrating to using `pydantic-settings`.
Docs: https://docs.pydantic.dev/latest/concepts/pydantic_settings/ @@ -78,9 +77,7 @@ class ConfigureTdf(BaseSettings): class ConfigureTesting(BaseSettings): - """ - Used by integration tests (in particular for SSH and Kubernetes access). - """ + """Used by integration tests (in particular for SSH and Kubernetes access).""" model_config = SettingsConfigDict( # env_prefix="common_", diff --git a/tests/integration/conftest.py b/tests/integration/conftest.py index 8cf0d77..8471981 100644 --- a/tests/integration/conftest.py +++ b/tests/integration/conftest.py @@ -1,6 +1,4 @@ -""" -Shared fixtures and utilities for integration tests. -""" +"""Shared fixtures and utilities for integration tests.""" import json import logging diff --git a/tests/integration/otdfctl_only/test_otdfctl_generated_fixtures.py b/tests/integration/otdfctl_only/test_otdfctl_generated_fixtures.py index 112d804..064639d 100644 --- a/tests/integration/otdfctl_only/test_otdfctl_generated_fixtures.py +++ b/tests/integration/otdfctl_only/test_otdfctl_generated_fixtures.py @@ -6,7 +6,6 @@ @pytest.mark.integration def test_test_data_directory_structure(tdf_v4_2_2_files, tdf_v4_3_1_files): """Test that the TDF files are properly generated by fixtures.""" - # Check v4.2.2 TDF files exist and are valid expected_v4_2_2_files = ["text", "binary", "with_attributes"] for file_key in expected_v4_2_2_files: @@ -44,7 +43,6 @@ def test_test_data_directory_structure(tdf_v4_2_2_files, tdf_v4_3_1_files): @pytest.mark.integration def test_sample_file_contents(sample_input_files): """Test that sample input files have expected content.""" - # Check text file has content text_file = sample_input_files["text"] assert text_file.exists(), f"Text file should exist: {text_file}" diff --git a/tests/integration/otdfctl_to_python/test_cli_comparison.py b/tests/integration/otdfctl_to_python/test_cli_comparison.py index cc7fd54..9916f4c 100644 --- a/tests/integration/otdfctl_to_python/test_cli_comparison.py +++ b/tests/integration/otdfctl_to_python/test_cli_comparison.py @@ -1,6 +1,4 @@ -""" -Test CLI functionality -""" +"""Test CLI functionality""" import tempfile from pathlib import Path @@ -24,7 +22,6 @@ def test_otdfctl_encrypt_python_decrypt( collect_server_logs, temp_credentials_file, project_root ): """Integration test that uses otdfctl for encryption and the Python CLI for decryption""" - # Create temporary directory for work with tempfile.TemporaryDirectory() as temp_dir: temp_path = Path(temp_dir) @@ -105,7 +102,6 @@ def test_otdfctl_encrypt_python_decrypt( @pytest.mark.integration def test_otdfctl_encrypt_otdfctl_decrypt(collect_server_logs, temp_credentials_file): """Integration test that uses otdfctl for both encryption and decryption to verify roundtrip functionality""" - # Create temporary directory for work with tempfile.TemporaryDirectory() as temp_dir: temp_path = Path(temp_dir) diff --git a/tests/integration/otdfctl_to_python/test_cli_decrypt.py b/tests/integration/otdfctl_to_python/test_cli_decrypt.py index dc61123..eaea2b5 100644 --- a/tests/integration/otdfctl_to_python/test_cli_decrypt.py +++ b/tests/integration/otdfctl_to_python/test_cli_decrypt.py @@ -1,6 +1,4 @@ -""" -Tests using target mode fixtures, for CLI integration testing. 
-""" +"""Tests using target mode fixtures, for CLI integration testing.""" import logging import subprocess @@ -21,10 +19,7 @@ def test_cli_decrypt_v4_2_2_vs_v4_3_1( all_target_mode_tdf_files, temp_credentials_file, collect_server_logs, project_root ): - """ - Test Python CLI decrypt with various TDF versions created by otdfctl. - """ - + """Test Python CLI decrypt with various TDF versions created by otdfctl.""" v4_2_2_files = all_target_mode_tdf_files["v4.2.2"] v4_3_1_files = all_target_mode_tdf_files["v4.3.1"] @@ -88,10 +83,7 @@ def test_cli_decrypt_different_file_types( project_root, known_target_modes, ): - """ - Test CLI decrypt with different file types. - """ - + """Test CLI decrypt with different file types.""" assert "v4.2.2" in all_target_mode_tdf_files assert "v4.3.1" in all_target_mode_tdf_files @@ -152,8 +144,7 @@ def test_cli_decrypt_different_file_types( def _run_cli_decrypt( tdf_path: Path, creds_file: Path, cwd: Path, collect_server_logs ) -> Path | None: - """ - Helper function to run Python CLI decrypt command and return the output file path. + """Helper function to run Python CLI decrypt command and return the output file path. Returns the Path to the decrypted output file if successful, None if failed. """ diff --git a/tests/integration/otdfctl_to_python/test_cli_inspect.py b/tests/integration/otdfctl_to_python/test_cli_inspect.py index 1ba39cf..c406355 100644 --- a/tests/integration/otdfctl_to_python/test_cli_inspect.py +++ b/tests/integration/otdfctl_to_python/test_cli_inspect.py @@ -1,6 +1,4 @@ -""" -Tests using target mode fixtures, for CLI integration testing. -""" +"""Tests using target mode fixtures, for CLI integration testing.""" import logging @@ -15,10 +13,7 @@ def test_cli_inspect_v4_2_2_vs_v4_3_1( all_target_mode_tdf_files, temp_credentials_file, project_root ): - """ - Test Python CLI inspect with various TDF versions created by otdfctl. - """ - + """Test Python CLI inspect with various TDF versions created by otdfctl.""" v4_2_2_files = all_target_mode_tdf_files["v4.2.2"] v4_3_1_files = all_target_mode_tdf_files["v4.3.1"] @@ -83,9 +78,7 @@ def test_cli_inspect_v4_2_2_vs_v4_3_1( def test_cli_inspect_different_file_types( all_target_mode_tdf_files, temp_credentials_file, project_root, known_target_modes ): - """ - Test CLI inspect with different file types. - """ + """Test CLI inspect with different file types.""" assert "v4.2.2" in all_target_mode_tdf_files assert "v4.3.1" in all_target_mode_tdf_files diff --git a/tests/integration/otdfctl_to_python/test_nanotdf_cli_comparison.py b/tests/integration/otdfctl_to_python/test_nanotdf_cli_comparison.py index ce5dac1..7560087 100644 --- a/tests/integration/otdfctl_to_python/test_nanotdf_cli_comparison.py +++ b/tests/integration/otdfctl_to_python/test_nanotdf_cli_comparison.py @@ -1,5 +1,4 @@ -""" -Integration tests for NanoTDF using otdfctl and Python CLI interoperability. +"""Integration tests for NanoTDF using otdfctl and Python CLI interoperability. These tests verify that: 1. 
otdfctl can encrypt to NanoTDF and Python can decrypt @@ -31,7 +30,6 @@ def test_otdfctl_encrypt_nano_python_decrypt( collect_server_logs, temp_credentials_file, project_root ): """Test otdfctl encrypt with --tdf-type nano and Python CLI decrypt.""" - with tempfile.TemporaryDirectory() as temp_dir: temp_path = Path(temp_dir) @@ -103,7 +101,6 @@ def test_python_encrypt_nano_otdfctl_decrypt( collect_server_logs, temp_credentials_file, project_root ): """Test Python CLI encrypt with --container-type nano and otdfctl decrypt.""" - with tempfile.TemporaryDirectory() as temp_dir: temp_path = Path(temp_dir) @@ -176,11 +173,10 @@ def test_python_encrypt_nano_otdfctl_decrypt( def test_nanotdf_roundtrip_comparison( collect_server_logs, temp_credentials_file, project_root ): - """ - Compare NanoTDF files created by otdfctl and Python CLI. + """Compare NanoTDF files created by otdfctl and Python CLI. + Tests both tools' roundtrip encryption/decryption. """ - with tempfile.TemporaryDirectory() as temp_dir: temp_path = Path(temp_dir) @@ -303,7 +299,6 @@ def test_nanotdf_with_attributes( collect_server_logs, temp_credentials_file, project_root ): """Test NanoTDF encryption/decryption with attributes.""" - with tempfile.TemporaryDirectory() as temp_dir: temp_path = Path(temp_dir) diff --git a/tests/integration/otdfctl_to_python/test_python_nanotdf_only.py b/tests/integration/otdfctl_to_python/test_python_nanotdf_only.py index 26cd61d..c04b59e 100644 --- a/tests/integration/otdfctl_to_python/test_python_nanotdf_only.py +++ b/tests/integration/otdfctl_to_python/test_python_nanotdf_only.py @@ -1,5 +1,5 @@ -""" -Simple NanoTDF integration test focusing on Python CLI only. +"""Simple NanoTDF integration test focusing on Python CLI only. + This tests the Python implementation without otdfctl dependency. """ @@ -23,7 +23,6 @@ def test_python_nanotdf_roundtrip( collect_server_logs, temp_credentials_file, project_root ): """Test Python CLI NanoTDF encryption and decryption roundtrip.""" - with tempfile.TemporaryDirectory() as temp_dir: temp_path = Path(temp_dir) diff --git a/tests/integration/otdfctl_to_python/test_tdf_reader_integration.py b/tests/integration/otdfctl_to_python/test_tdf_reader_integration.py index 457b25f..7cd5f3a 100644 --- a/tests/integration/otdfctl_to_python/test_tdf_reader_integration.py +++ b/tests/integration/otdfctl_to_python/test_tdf_reader_integration.py @@ -1,6 +1,4 @@ -""" -Integration Tests for TDFReader. 
-""" +"""Integration Tests for TDFReader.""" import io import json @@ -25,7 +23,6 @@ def test_read_otdfctl_created_tdf_structure( self, temp_credentials_file, collect_server_logs ): """Test that TDFReader can parse the structure of files created by otdfctl.""" - # Create temporary directory for work with tempfile.TemporaryDirectory() as temp_dir: temp_path = Path(temp_dir) @@ -108,7 +105,6 @@ def test_read_otdfctl_tdf_with_attributes( self, temp_credentials_file, collect_server_logs ): """Test reading TDF files created by otdfctl with data attributes.""" - # Create temporary directory for work with tempfile.TemporaryDirectory() as temp_dir: temp_path = Path(temp_dir) @@ -181,7 +177,6 @@ def test_read_multiple_otdfctl_files( self, temp_credentials_file, collect_server_logs ): """Test reading multiple TDF files of different types created by otdfctl.""" - # Create temporary directory for work with tempfile.TemporaryDirectory() as temp_dir: temp_path = Path(temp_dir) diff --git a/tests/integration/python_only/test_kas_client_integration.py b/tests/integration/python_only/test_kas_client_integration.py index 97fc723..e58fb02 100644 --- a/tests/integration/python_only/test_kas_client_integration.py +++ b/tests/integration/python_only/test_kas_client_integration.py @@ -1,6 +1,4 @@ -""" -Integration tests for KASClient. -""" +"""Integration tests for KASClient.""" import pytest diff --git a/tests/integration/support_sdk.py b/tests/integration/support_sdk.py index 0d93ba3..9fe11f9 100644 --- a/tests/integration/support_sdk.py +++ b/tests/integration/support_sdk.py @@ -51,8 +51,7 @@ def get_user_access_token( pe_username, pe_password, ): - """ - When using this function, ensure that: + """When using this function, ensure that: 1. The client has "fine-grained access control" enabled (in the Advanced tab for the client in Keycloak). 2. The client is allowed to use "Direct access grants" (in the Settings tab for the client in Keycloak). diff --git a/tests/integration/test_cli_integration.py b/tests/integration/test_cli_integration.py index f830e69..d03c763 100644 --- a/tests/integration/test_cli_integration.py +++ b/tests/integration/test_cli_integration.py @@ -1,6 +1,4 @@ -""" -Integration Test CLI functionality -""" +"""Integration Test CLI functionality""" import tempfile from pathlib import Path @@ -24,10 +22,7 @@ def test_cli_decrypt_otdfctl_tdf( collect_server_logs, temp_credentials_file, project_root ): - """ - Test that the Python CLI can successfully decrypt TDF files created by otdfctl. - """ - + """Test that the Python CLI can successfully decrypt TDF files created by otdfctl.""" # Create temporary directory for work with tempfile.TemporaryDirectory() as temp_dir: temp_path = Path(temp_dir) @@ -88,10 +83,7 @@ def test_cli_decrypt_otdfctl_tdf( def test_otdfctl_decrypt_comparison( collect_server_logs, temp_credentials_file, project_root ): - """ - Test comparative decryption between otdfctl and Python CLI on the same TDF. - """ - + """Test comparative decryption between otdfctl and Python CLI on the same TDF.""" # Create temporary directory for work with tempfile.TemporaryDirectory() as temp_dir: temp_path = Path(temp_dir) @@ -170,10 +162,7 @@ def test_otdfctl_decrypt_comparison( @pytest.mark.integration def test_otdfctl_encrypt_decrypt_roundtrip(collect_server_logs, temp_credentials_file): - """ - Test complete encrypt-decrypt roundtrip using otdfctl to verify functionality. 
- """ - + """Test complete encrypt-decrypt roundtrip using otdfctl to verify functionality.""" # Create temporary directory for work with tempfile.TemporaryDirectory() as temp_dir: temp_path = Path(temp_dir) @@ -250,7 +239,6 @@ def test_cli_encrypt_integration( collect_server_logs, temp_credentials_file, project_root ): """Integration test comparing our CLI with otdfctl""" - # Create temporary directory for work with tempfile.TemporaryDirectory() as temp_dir: temp_path = Path(temp_dir) diff --git a/tests/integration/test_cli_tdf_validation.py b/tests/integration/test_cli_tdf_validation.py index 54a180a..91d65ec 100644 --- a/tests/integration/test_cli_tdf_validation.py +++ b/tests/integration/test_cli_tdf_validation.py @@ -1,6 +1,4 @@ -""" -Test CLI encryption functionality and TDF validation -""" +"""Test CLI encryption functionality and TDF validation""" import json import tempfile @@ -285,7 +283,6 @@ def _run_python_cli_decrypt( @pytest.mark.integration def test_otdfctl_encrypt_with_validation(collect_server_logs, temp_credentials_file): """Integration test that uses otdfctl for encryption and validates the TDF thoroughly.""" - # Create temporary directory for work with tempfile.TemporaryDirectory() as temp_dir: temp_path = Path(temp_dir) @@ -331,7 +328,6 @@ def test_otdfctl_encrypt_with_validation(collect_server_logs, temp_credentials_f @pytest.mark.integration def test_python_encrypt(collect_server_logs, temp_credentials_file, project_root): """Integration test that uses Python CLI for encryption only and verifies the TDF can be inspected""" - # Create temporary directory for work with tempfile.TemporaryDirectory() as temp_dir: temp_path = Path(temp_dir) @@ -380,7 +376,6 @@ def test_cross_tool_compatibility( collect_server_logs, temp_credentials_file, project_root ): """Test that TDFs created by one tool can be decrypted by the other.""" - # Create temporary directory for work with tempfile.TemporaryDirectory() as temp_dir: temp_path = Path(temp_dir) @@ -453,7 +448,6 @@ def test_different_content_types( collect_server_logs, temp_credentials_file, project_root ): """Test encryption/decryption with different types of content.""" - test_cases = [ ("short.txt", "x"), # Single character ("multiline.txt", "Line 1\nLine 2\nLine 3\n"), # Multi-line content @@ -513,7 +507,6 @@ def test_different_content_types_empty( collect_server_logs, temp_credentials_file, project_root ): """Test encryption/decryption with different types of content.""" - test_cases = [ ("empty.txt", ""), # Empty file ] diff --git a/tests/integration/test_pe_interaction.py b/tests/integration/test_pe_interaction.py index 80ebaee..ecf839a 100644 --- a/tests/integration/test_pe_interaction.py +++ b/tests/integration/test_pe_interaction.py @@ -1,6 +1,4 @@ -""" -Integration test: Single attribute encryption/decryption using SDK and otdfctl -""" +"""Integration test: Single attribute encryption/decryption using SDK and otdfctl""" import logging import tempfile diff --git a/tests/server_logs.py b/tests/server_logs.py index 0bc846d..a5521c4 100644 --- a/tests/server_logs.py +++ b/tests/server_logs.py @@ -1,6 +1,4 @@ -""" -Server log collection utility for debugging test failures. -""" +"""Server log collection utility for debugging test failures.""" import logging import subprocess @@ -17,17 +15,18 @@ def collect_server_logs( lines: int = CONFIG_TESTING.LOG_LINES, test_name: str | None = None, ) -> str | None: - """ - Collect server logs from a Kubernetes pod via SSH. + """Collect server logs from a Kubernetes pod via SSH. 
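For example, a failing integration test could pull a bounded log tail like this. It assumes pod_name, namespace and ssh_target default from CONFIG_TESTING just as lines does, which matches the pattern of log_server_logs_on_failure below but is not verified here.

    from tests.server_logs import collect_server_logs

    # Relies on CONFIG_TESTING defaults for pod, namespace and SSH target (assumption).
    logs = collect_server_logs(lines=200, test_name="test_cli_decrypt_roundtrip")
    if logs:
        print(logs)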
Args: pod_name: Name of the Kubernetes pod namespace: Kubernetes namespace ssh_target: SSH target (hostname/alias) lines: Number of log lines to retrieve + test_name: Test name for log file naming Returns: Log output as string, or None if collection failed + """ if CONFIG_TESTING.ENABLE_LOG_COLLECTION: logging.debug(f"\n{'=' * 60}") @@ -73,8 +72,7 @@ def log_server_logs_on_failure( ssh_target: str = CONFIG_TESTING.SSH_TARGET, lines: int = CONFIG_TESTING.LOG_LINES, ) -> None: - """ - Collect and log server logs when a test fails. + """Collect and log server logs when a test fails. Args: test_name: Name of the failed test @@ -82,8 +80,8 @@ def log_server_logs_on_failure( namespace: Kubernetes namespace ssh_target: SSH target (hostname/alias) lines: Number of log lines to retrieve - """ + """ logs = collect_server_logs(pod_name, namespace, ssh_target, lines, test_name) if logs: diff --git a/tests/support_cli_args.py b/tests/support_cli_args.py index ab2900d..5975a49 100644 --- a/tests/support_cli_args.py +++ b/tests/support_cli_args.py @@ -1,6 +1,4 @@ -""" -Support functions for constructing CLI arguments for this project's (Python) CLI. -""" +"""Support functions for constructing CLI arguments for this project's (Python) CLI.""" import json import logging @@ -15,9 +13,7 @@ def _get_cli_flags() -> list[str]: - """ - Determine (Python) CLI flags based on platform URL - """ + """Determine (Python) CLI flags based on platform URL""" platform_url = get_platform_url() cli_flags = [] @@ -32,13 +28,11 @@ def _get_cli_flags() -> list[str]: def run_cli_inspect(tdf_path: Path, creds_file: Path, cwd: Path) -> dict: - """ - Helper function to run Python CLI inspect command and return parsed JSON result. + """Helper function to run Python CLI inspect command and return parsed JSON result. This demonstrates how the CLI inspect functionality could be tested with the new fixtures. """ - # Build CLI command cmd = [ sys.executable, diff --git a/tests/support_common.py b/tests/support_common.py index 153cf9e..92ea364 100644 --- a/tests/support_common.py +++ b/tests/support_common.py @@ -37,8 +37,7 @@ def handle_subprocess_error( def get_testing_environ() -> dict | None: - """ - Set up environment and configuration + """Set up environment and configuration TODO: YAGNI: this is a hook we could use to modify all testing environments, e.g. env = os.environ.copy() diff --git a/tests/support_otdfctl.py b/tests/support_otdfctl.py index 80f411f..dd24394 100644 --- a/tests/support_otdfctl.py +++ b/tests/support_otdfctl.py @@ -6,13 +6,11 @@ @pytest.mark.integration @pytest.fixture(scope="session", autouse=True) def check_for_otdfctl(): - """ - Ensure that the otdfctl command is available on the system. + """Ensure that the otdfctl command is available on the system. This fixture runs once per test session (for integration tests) and raises an exception if the otdfctl command is not found. """ - # Check if otdfctl is available try: subprocess.run( diff --git a/tests/support_otdfctl_args.py b/tests/support_otdfctl_args.py index 99dcc4a..f5dd08e 100644 --- a/tests/support_otdfctl_args.py +++ b/tests/support_otdfctl_args.py @@ -1,6 +1,4 @@ -""" -Support functions for constructing CLI arguments for otdfctl CLI. 
-""" +"""Support functions for constructing CLI arguments for otdfctl CLI.""" import logging import subprocess @@ -13,9 +11,7 @@ def get_otdfctl_flags() -> list[str]: - """ - Determine otdfctl flags based on platform URL - """ + """Determine otdfctl flags based on platform URL""" platform_url = get_platform_url() otdfctl_flags = [] if platform_url.startswith("http://"): @@ -68,8 +64,8 @@ def _build_otdfctl_encrypt_command( attributes: Optional list of attributes to apply tdf_type: TDF type (e.g., "tdf3", "nano") target_mode: Target TDF spec version (e.g., "v4.2.2", "v4.3.1") - """ + """ cmd = get_otdfctl_base_command(creds_file, platform_url) cmd.append("encrypt") cmd.extend(["--mime-type", mime_type]) @@ -215,8 +211,7 @@ def otdfctl_generate_tdf_files_for_target_mode( test_data_dir: Path, sample_input_files: dict[str, Path], ) -> dict[str, Path]: - """ - Factory function to generate TDF files for a specific target mode. + """Factory function to generate TDF files for a specific target mode. Args: target_mode: Target TDF spec version (e.g., "v4.2.2", "v4.3.1") @@ -226,6 +221,7 @@ def otdfctl_generate_tdf_files_for_target_mode( Returns: Dictionary mapping file types to their TDF file paths + """ output_dir = test_data_dir / target_mode tdf_files = {} diff --git a/tests/test_address_normalizer.py b/tests/test_address_normalizer.py index 01577f0..cfcdea3 100644 --- a/tests/test_address_normalizer.py +++ b/tests/test_address_normalizer.py @@ -1,6 +1,4 @@ -""" -Tests for address_normalizer module. -""" +"""Tests for address_normalizer module.""" import pytest diff --git a/tests/test_cli.py b/tests/test_cli.py index b1d8b4d..5053ca8 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -1,6 +1,4 @@ -""" -Test CLI functionality -""" +"""Test CLI functionality""" import subprocess import sys diff --git a/tests/test_ecdh.py b/tests/test_ecdh.py index 537d1a9..6dba667 100644 --- a/tests/test_ecdh.py +++ b/tests/test_ecdh.py @@ -1,6 +1,4 @@ -""" -Unit tests for ECDH key exchange module. -""" +"""Unit tests for ECDH key exchange module.""" import pytest from cryptography.hazmat.primitives import serialization diff --git a/tests/test_kas_client.py b/tests/test_kas_client.py index b4c3b17..3591699 100644 --- a/tests/test_kas_client.py +++ b/tests/test_kas_client.py @@ -1,6 +1,4 @@ -""" -Unit tests for KASClient. -""" +"""Unit tests for KASClient.""" from base64 import b64decode from unittest.mock import MagicMock, patch @@ -13,7 +11,19 @@ class MockKasInfo: + """Mock KAS info for testing.""" + def __init__(self, url, algorithm=None, public_key=None, kid=None, default=False): + """Initialize MockKasInfo. + + Args: + url: KAS URL. + algorithm: Key algorithm. + public_key: Public key. + kid: Key ID. + default: Whether this is the default KAS. 
+ + """ self.url = url self.algorithm = algorithm or "" self.public_key = public_key or "" @@ -31,6 +41,7 @@ def clone(self): def test_get_public_key_uses_cache(): + """Test that get_public_key uses cached KAS info.""" cache = KASKeyCache() kas_info = MockKasInfo(url="http://kas") # Store in cache using the new mechanism @@ -45,6 +56,7 @@ def test_get_public_key_uses_cache(): def test_get_public_key_fetches_and_caches( mock_access_service_client, mock_pool_manager ): + """Test that get_public_key fetches and caches public key.""" cache = KASKeyCache() client = KASClient("http://kas", cache=cache) @@ -105,6 +117,7 @@ def test_unwrap_success( mock_access_service_client, mock_pool_manager, ): + """Test successful key unwrap operation.""" # Setup mocks for RSA key pair generation and decryption mock_private_key = MagicMock() mock_public_key = MagicMock() @@ -168,6 +181,7 @@ def test_unwrap_success( @patch("urllib3.PoolManager") @patch("otdf_python_proto.kas.kas_pb2_connect.AccessServiceClient") def test_unwrap_failure(mock_access_service_client, mock_pool_manager): + """Test key unwrap failure handling.""" # Setup realistic HTTP response mock for PoolManager mock_response = MagicMock() mock_response.status = 500 diff --git a/tests/test_kas_key_cache.py b/tests/test_kas_key_cache.py index b4e6cc7..0bd740b 100644 --- a/tests/test_kas_key_cache.py +++ b/tests/test_kas_key_cache.py @@ -1,6 +1,4 @@ -""" -Unit tests for KASKeyCache. -""" +"""Unit tests for KASKeyCache.""" from dataclasses import dataclass @@ -9,6 +7,8 @@ @dataclass class MockKasInfo: + """Mock KAS info for testing.""" + url: str algorithm: str | None = None public_key: str | None = None @@ -17,6 +17,7 @@ class MockKasInfo: def test_kas_key_cache_set_and_get(): + """Test KASKeyCache set and get operations.""" cache = KASKeyCache() # Use the new store/get interface kas_info = MockKasInfo(url="http://example.com") @@ -25,6 +26,7 @@ def test_kas_key_cache_set_and_get(): def test_kas_key_cache_overwrite(): + """Test KASKeyCache overwriting cached values.""" cache = KASKeyCache() # Test overwriting with new values kas_info1 = MockKasInfo(url="http://example.com") @@ -38,6 +40,7 @@ def test_kas_key_cache_overwrite(): def test_kas_key_cache_clear(): + """Test KASKeyCache clear operation.""" cache = KASKeyCache() cache.set("key1", "value1") cache.clear() diff --git a/tests/test_kas_key_management.py b/tests/test_kas_key_management.py index 3921de5..110972b 100644 --- a/tests/test_kas_key_management.py +++ b/tests/test_kas_key_management.py @@ -1,3 +1,5 @@ +"""Tests for KAS key management.""" + import base64 import os import unittest diff --git a/tests/test_key_type.py b/tests/test_key_type.py index e899b90..f5d1339 100644 --- a/tests/test_key_type.py +++ b/tests/test_key_type.py @@ -1,14 +1,20 @@ +"""Tests for KeyType.""" + import unittest from otdf_python.key_type import KeyType class TestKeyType(unittest.TestCase): + """Tests for KeyType class.""" + def test_str(self): + """Test KeyType string representation.""" self.assertEqual(str(KeyType.RSA2048Key), "rsa:2048") self.assertEqual(str(KeyType.EC256Key), "ec:secp256r1") def test_get_curve_name(self): + """Test KeyType get_curve_name method.""" self.assertEqual(KeyType.EC256Key.get_curve_name(), "secp256r1") self.assertEqual(KeyType.EC384Key.get_curve_name(), "secp384r1") self.assertEqual(KeyType.EC521Key.get_curve_name(), "secp521r1") @@ -16,12 +22,14 @@ def test_get_curve_name(self): KeyType.RSA2048Key.get_curve_name() def test_from_string(self): + """Test KeyType from_string method.""" 
self.assertEqual(KeyType.from_string("rsa:2048"), KeyType.RSA2048Key) self.assertEqual(KeyType.from_string("ec:secp256r1"), KeyType.EC256Key) with self.assertRaises(ValueError): KeyType.from_string("notakey") def test_is_ec(self): + """Test KeyType is_ec method.""" self.assertTrue(KeyType.EC256Key.is_ec()) self.assertFalse(KeyType.RSA2048Key.is_ec()) diff --git a/tests/test_log_collection.py b/tests/test_log_collection.py index 1845290..a0c8643 100644 --- a/tests/test_log_collection.py +++ b/tests/test_log_collection.py @@ -1,6 +1,5 @@ #!/usr/bin/env python3 -""" -Test script to verify server log collection functionality. +"""Test script to verify server log collection functionality. This script tests the server log collection without running full pytest. """ diff --git a/tests/test_manifest.py b/tests/test_manifest.py index f9e36d1..864601c 100644 --- a/tests/test_manifest.py +++ b/tests/test_manifest.py @@ -1,3 +1,5 @@ +"""Tests for TDF manifest.""" + from otdf_python.manifest import ( Manifest, ManifestAssertion, diff --git a/tests/test_manifest_format.py b/tests/test_manifest_format.py index 7be368d..9b61431 100644 --- a/tests/test_manifest_format.py +++ b/tests/test_manifest_format.py @@ -1,6 +1,4 @@ -""" -Test TDF manifest format, inspired by the Java SDK manifest tests. -""" +"""Test TDF manifest format, inspired by the Java SDK manifest tests.""" import json @@ -11,7 +9,6 @@ def test_manifest_field_format(): """Test that manifest uses camelCase field names as per TDF specification.""" - # Create a mock KAS info with public key to avoid network calls kas_private_key, kas_public_key = generate_rsa_keypair() kas_info = KASInfo( diff --git a/tests/test_nanotdf.py b/tests/test_nanotdf.py index 31db9fc..23acbd7 100644 --- a/tests/test_nanotdf.py +++ b/tests/test_nanotdf.py @@ -1,3 +1,5 @@ +"""Tests for NanoTDF.""" + import secrets import pytest @@ -7,6 +9,7 @@ def test_nanotdf_roundtrip(): + """Test NanoTDF encrypt and decrypt roundtrip.""" nanotdf = NanoTDF() key = secrets.token_bytes(32) data = b"nano tdf test payload" @@ -18,6 +21,7 @@ def test_nanotdf_roundtrip(): def test_nanotdf_too_large(): + """Test NanoTDF with payload exceeding size limit.""" nanotdf = NanoTDF() key = secrets.token_bytes(32) data = b"x" * (NanoTDF.K_MAX_TDF_SIZE + 1) @@ -27,6 +31,7 @@ def test_nanotdf_too_large(): def test_nanotdf_invalid_magic(): + """Test NanoTDF with invalid magic bytes.""" nanotdf = NanoTDF() key = secrets.token_bytes(32) config = NanoTDFConfig(cipher=key.hex()) @@ -37,6 +42,7 @@ def test_nanotdf_invalid_magic(): @pytest.mark.integration def test_nanotdf_integration_encrypt_decrypt(): + """Test NanoTDF integration with KAS.""" # Load environment variables for integration from otdf_python.config import KASInfo from tests.config_pydantic import CONFIG_TDF diff --git a/tests/test_nanotdf_ecdh.py b/tests/test_nanotdf_ecdh.py index f11f5b3..d1fec60 100644 --- a/tests/test_nanotdf_ecdh.py +++ b/tests/test_nanotdf_ecdh.py @@ -1,6 +1,4 @@ -""" -Integration tests for NanoTDF with ECDH key exchange. -""" +"""Integration tests for NanoTDF with ECDH key exchange.""" import io diff --git a/tests/test_nanotdf_ecdsa_struct.py b/tests/test_nanotdf_ecdsa_struct.py index d83eb16..994f9d7 100644 --- a/tests/test_nanotdf_ecdsa_struct.py +++ b/tests/test_nanotdf_ecdsa_struct.py @@ -1,6 +1,4 @@ -""" -Tests for NanoTDFECDSAStruct. 
-""" +"""Tests for NanoTDFECDSAStruct.""" import pytest diff --git a/tests/test_nanotdf_integration.py b/tests/test_nanotdf_integration.py index a4bf79e..945d020 100644 --- a/tests/test_nanotdf_integration.py +++ b/tests/test_nanotdf_integration.py @@ -1,3 +1,5 @@ +"""Tests for NanoTDF integration.""" + import io import pytest @@ -10,6 +12,7 @@ @pytest.mark.integration def test_nanotdf_kas_roundtrip(): + """Test NanoTDF KAS integration roundtrip.""" # Generate EC keypair (NanoTDF uses ECDH, not RSA) private_key = ec.generate_private_key(ec.SECP256R1()) private_pem = private_key.private_bytes( diff --git a/tests/test_nanotdf_type.py b/tests/test_nanotdf_type.py index c93c8b8..b7d788c 100644 --- a/tests/test_nanotdf_type.py +++ b/tests/test_nanotdf_type.py @@ -1,3 +1,5 @@ +"""Tests for NanoTDF types.""" + import unittest from otdf_python.nanotdf_type import ( @@ -10,23 +12,29 @@ class TestNanoTDFType(unittest.TestCase): + """Tests for NanoTDF type enums.""" + def test_eccurve(self): + """Test ECCurve enum values.""" self.assertEqual(str(ECCurve.SECP256R1), "secp256r1") self.assertEqual(str(ECCurve.SECP384R1), "secp384r1") self.assertEqual(str(ECCurve.SECP521R1), "secp384r1") self.assertEqual(str(ECCurve.SECP256K1), "secp256k1") def test_protocol(self): + """Test Protocol enum values.""" self.assertEqual(Protocol.HTTP.value, "HTTP") self.assertEqual(Protocol.HTTPS.value, "HTTPS") def test_identifier_type(self): + """Test IdentifierType enum values.""" self.assertEqual(IdentifierType.NONE.get_length(), 0) self.assertEqual(IdentifierType.TWO_BYTES.get_length(), 2) self.assertEqual(IdentifierType.EIGHT_BYTES.get_length(), 8) self.assertEqual(IdentifierType.THIRTY_TWO_BYTES.get_length(), 32) def test_policy_type(self): + """Test PolicyType enum values.""" self.assertEqual(PolicyType.REMOTE_POLICY.value, 0) self.assertEqual(PolicyType.EMBEDDED_POLICY_PLAIN_TEXT.value, 1) self.assertEqual(PolicyType.EMBEDDED_POLICY_ENCRYPTED.value, 2) @@ -35,6 +43,7 @@ def test_policy_type(self): ) def test_cipher(self): + """Test Cipher enum values.""" self.assertEqual(Cipher.AES_256_GCM_64_TAG.value, 0) self.assertEqual(Cipher.AES_256_GCM_128_TAG.value, 5) self.assertEqual(Cipher.EAD_AES_256_HMAC_SHA_256.value, 6) diff --git a/tests/test_policy_object.py b/tests/test_policy_object.py index a0ceb01..80091a4 100644 --- a/tests/test_policy_object.py +++ b/tests/test_policy_object.py @@ -1,10 +1,15 @@ +"""Tests for policy objects.""" + import unittest from otdf_python.policy_object import AttributeObject, PolicyBody, PolicyObject class TestPolicyObject(unittest.TestCase): + """Tests for policy object classes.""" + def test_attribute_object(self): + """Test AttributeObject creation and properties.""" attr = AttributeObject( attribute="attr1", display_name="Attribute 1", @@ -19,6 +24,7 @@ def test_attribute_object(self): self.assertEqual(attr.kas_url, "https://kas.example.com") def test_policy_body(self): + """Test PolicyBody creation and properties.""" attr1 = AttributeObject(attribute="attr1") attr2 = AttributeObject(attribute="attr2") body = PolicyBody(data_attributes=[attr1, attr2], dissem=["user1", "user2"]) @@ -27,6 +33,7 @@ def test_policy_body(self): self.assertIn("user2", body.dissem) def test_policy_object(self): + """Test PolicyObject creation and properties.""" attr = AttributeObject(attribute="attr1") body = PolicyBody(data_attributes=[attr], dissem=["user1"]) policy = PolicyObject(uuid="uuid-1234", body=body) diff --git a/tests/test_sdk.py b/tests/test_sdk.py index 9c217f1..5d84c54 100644 --- 
a/tests/test_sdk.py +++ b/tests/test_sdk.py @@ -1,19 +1,22 @@ -""" -Basic tests for the Python SDK class port. -""" +"""Basic tests for the Python SDK class.""" from otdf_python.sdk import SDK class DummyServices(SDK.Services): + """Dummy SDK services for testing.""" + def close(self): + """Close the services.""" self.closed = True def __exit__(self, exc_type, exc_val, exc_tb): + """Exit context manager.""" pass def test_sdk_init_and_close(): + """Test SDK initialization and close.""" services = DummyServices() sdk = SDK(services) assert sdk.get_services() is services @@ -25,6 +28,7 @@ def test_sdk_init_and_close(): def test_split_key_exception(): + """Test SDK SplitKeyException.""" try: raise SDK.SplitKeyException("split key error") except SDK.SplitKeyException: @@ -32,6 +36,7 @@ def test_split_key_exception(): def test_data_size_not_supported(): + """Test SDK DataSizeNotSupported exception.""" try: raise SDK.DataSizeNotSupported("too large") except SDK.DataSizeNotSupported: @@ -39,6 +44,7 @@ def test_data_size_not_supported(): def test_kas_info_missing(): + """Test SDK KasInfoMissing exception.""" try: raise SDK.KasInfoMissing("kas info missing") except SDK.KasInfoMissing: @@ -46,6 +52,7 @@ def test_kas_info_missing(): def test_kas_public_key_missing(): + """Test SDK KasPublicKeyMissing exception.""" try: raise SDK.KasPublicKeyMissing("kas pubkey missing") except SDK.KasPublicKeyMissing: @@ -53,6 +60,7 @@ def test_kas_public_key_missing(): def test_tamper_exception(): + """Test SDK TamperException.""" try: raise SDK.TamperException("tamper") except SDK.TamperException: @@ -60,6 +68,7 @@ def test_tamper_exception(): def test_root_signature_validation_exception(): + """Test SDK RootSignatureValidationException.""" try: raise SDK.RootSignatureValidationException("root sig") except SDK.RootSignatureValidationException: @@ -67,6 +76,7 @@ def test_root_signature_validation_exception(): def test_segment_signature_mismatch(): + """Test SDK SegmentSignatureMismatch exception.""" try: raise SDK.SegmentSignatureMismatch("seg sig") except SDK.SegmentSignatureMismatch: @@ -74,6 +84,7 @@ def test_segment_signature_mismatch(): def test_kas_bad_request_exception(): + """Test SDK KasBadRequestException.""" try: raise SDK.KasBadRequestException("kas bad req") except SDK.KasBadRequestException: @@ -81,6 +92,7 @@ def test_kas_bad_request_exception(): def test_kas_allowlist_exception(): + """Test SDK KasAllowlistException.""" try: raise SDK.KasAllowlistException("kas allowlist") except SDK.KasAllowlistException: @@ -88,6 +100,7 @@ def test_kas_allowlist_exception(): def test_assertion_exception(): + """Test SDK AssertionException.""" try: raise SDK.AssertionException("assertion", "id123") except SDK.AssertionException: diff --git a/tests/test_sdk_builder.py b/tests/test_sdk_builder.py index 756b2c7..2b499f6 100644 --- a/tests/test_sdk_builder.py +++ b/tests/test_sdk_builder.py @@ -1,6 +1,4 @@ -""" -Tests for the SDKBuilder class. 
-""" +"""Tests for the SDKBuilder class.""" import tempfile from pathlib import Path diff --git a/tests/test_sdk_exceptions.py b/tests/test_sdk_exceptions.py index 92ddf4b..4c52c99 100644 --- a/tests/test_sdk_exceptions.py +++ b/tests/test_sdk_exceptions.py @@ -1,15 +1,21 @@ +"""Tests for SDK exceptions.""" + import unittest from otdf_python.sdk_exceptions import AutoConfigureException, SDKException class TestSDKExceptions(unittest.TestCase): + """Tests for SDK exception classes.""" + def test_sdk_exception(self): + """Test SDKException creation and properties.""" e = SDKException("msg", Exception("reason")) self.assertEqual(str(e), "msg") self.assertIsInstance(e.reason, Exception) def test_auto_configure_exception(self): + """Test AutoConfigureException creation and properties.""" e = AutoConfigureException("fail", Exception("cause")) self.assertEqual(str(e), "fail") self.assertIsInstance(e.reason, Exception) diff --git a/tests/test_sdk_mock.py b/tests/test_sdk_mock.py index d7c574f..5e3f27f 100644 --- a/tests/test_sdk_mock.py +++ b/tests/test_sdk_mock.py @@ -1,29 +1,42 @@ +"""Mock SDK components for testing.""" + from otdf_python.sdk import KAS, SDK class MockKAS(KAS): + """Mock KAS implementation for testing.""" + def get_public_key(self, kas_info): + """Return mock public key.""" return "mock-public-key" def get_ec_public_key(self, kas_info, curve): + """Return mock EC public key.""" return "mock-ec-public-key" def unwrap(self, key_access, policy, session_key_type): + """Return mock unwrapped key.""" return b"mock-unwrapped-key" def unwrap_nanotdf(self, curve, header, kas_url): + """Return mock unwrapped NanoTDF key.""" return b"mock-unwrapped-nanotdf" def get_key_cache(self): + """Return mock key cache.""" return None class MockServices(SDK.Services): + """Mock SDK services for testing.""" + def kas(self): + """Return mock KAS instance.""" return MockKAS() def test_sdk_instantiation(): + """Test SDK instantiation with mock services.""" services = MockServices() sdk = SDK(services=services) assert sdk.get_services() is services diff --git a/tests/test_sdk_tdf_integration.py b/tests/test_sdk_tdf_integration.py index 85b2a85..c8143e9 100644 --- a/tests/test_sdk_tdf_integration.py +++ b/tests/test_sdk_tdf_integration.py @@ -1,6 +1,4 @@ -""" -Tests for the integration between SDK and TDF classes. -""" +"""Tests for the integration between SDK and TDF classes.""" import io diff --git a/tests/test_tdf.py b/tests/test_tdf.py index 1dcc8c9..650eb61 100644 --- a/tests/test_tdf.py +++ b/tests/test_tdf.py @@ -1,3 +1,5 @@ +"""Tests for TDF.""" + import io import json import zipfile @@ -11,6 +13,7 @@ def test_tdf_create_and_load(): + """Test TDF creation and loading roundtrip.""" tdf = TDF() payload = b"test payload" kas_private_key, kas_public_key = generate_rsa_keypair() @@ -40,6 +43,7 @@ def test_tdf_create_and_load(): @pytest.mark.integration def test_tdf_multi_kas_roundtrip(): + """Test TDF with multiple KAS roundtrip.""" tdf = TDF() payload = b"multi-kas test payload" # Generate two KAS keypairs diff --git a/tests/test_tdf_key_management.py b/tests/test_tdf_key_management.py index f050577..4f749d3 100644 --- a/tests/test_tdf_key_management.py +++ b/tests/test_tdf_key_management.py @@ -1,3 +1,5 @@ +"""Tests for TDF key management.""" + import base64 import io import unittest diff --git a/tests/test_tdf_reader.py b/tests/test_tdf_reader.py index 663ccea..1c60511 100644 --- a/tests/test_tdf_reader.py +++ b/tests/test_tdf_reader.py @@ -1,6 +1,4 @@ -""" -Tests for TDFReader. 
-""" +"""Tests for TDFReader.""" import io import json diff --git a/tests/test_tdf_writer.py b/tests/test_tdf_writer.py index 4d6ff79..354045a 100644 --- a/tests/test_tdf_writer.py +++ b/tests/test_tdf_writer.py @@ -1,3 +1,5 @@ +"""Tests for TDFWriter.""" + import io import unittest import zipfile @@ -6,7 +8,10 @@ class TestTDFWriter(unittest.TestCase): + """Tests for TDFWriter class.""" + def test_append_manifest_and_payload(self): + """Test appending manifest and payload.""" out = io.BytesIO() writer = TDFWriter(out) manifest = '{"foo": "bar"}' @@ -21,6 +26,7 @@ def test_append_manifest_and_payload(self): self.assertEqual(z.read("0.payload"), b"payload data") def test_getvalue(self): + """Test getting writer value as bytes.""" writer = TDFWriter() writer.append_manifest("{}") with writer.payload() as f: @@ -32,6 +38,7 @@ def test_getvalue(self): self.assertEqual(z.read("0.payload"), b"abc") def test_large_payload_chunks(self): + """Test writing large payload in chunks.""" out = io.BytesIO() writer = TDFWriter(out) writer.append_manifest('{"test": true}') @@ -45,6 +52,7 @@ def test_large_payload_chunks(self): self.assertEqual(z.read("0.payload"), chunk * 5) def test_error_on_write_after_finish(self): + """Test error when writing after finish.""" out = io.BytesIO() writer = TDFWriter(out) writer.append_manifest("{}") diff --git a/tests/test_token_source.py b/tests/test_token_source.py index 5bc9e28..ae06d9e 100644 --- a/tests/test_token_source.py +++ b/tests/test_token_source.py @@ -1,6 +1,4 @@ -""" -Unit tests for TokenSource. -""" +"""Unit tests for TokenSource.""" import time from unittest.mock import MagicMock, patch @@ -9,6 +7,7 @@ def test_token_source_returns_token_and_caches(): + """Test TokenSource returns token and caches it.""" with patch("httpx.post") as mock_post: mock_resp = MagicMock() mock_resp.json.return_value = {"access_token": "abc", "expires_in": 100} @@ -26,6 +25,7 @@ def test_token_source_returns_token_and_caches(): @patch("httpx.post") def test_token_source_refreshes_token(mock_post): + """Test TokenSource refreshes expired token.""" mock_resp1 = MagicMock() mock_resp1.json.return_value = {"access_token": "abc", "expires_in": 1} mock_resp1.raise_for_status.return_value = None diff --git a/tests/test_url_normalization.py b/tests/test_url_normalization.py index a5757fc..23504d6 100644 --- a/tests/test_url_normalization.py +++ b/tests/test_url_normalization.py @@ -1,6 +1,5 @@ #!/usr/bin/env python3 -""" -Test script to verify URL normalization functionality is working correctly. +"""Test script to verify URL normalization functionality is working correctly. This script tests the _normalize_kas_url method to ensure it properly respects the use_plaintext setting when converting URLs. @@ -10,6 +9,7 @@ def test_url_normalization(): + """Test KAS URL normalization with plaintext settings.""" print("Testing with use_plaintext=True:") client_plaintext = KASClient(use_plaintext=True) diff --git a/tests/test_use_plaintext_flow.py b/tests/test_use_plaintext_flow.py index 3019bed..d882440 100644 --- a/tests/test_use_plaintext_flow.py +++ b/tests/test_use_plaintext_flow.py @@ -1,6 +1,4 @@ -""" -Test to verify that the use_plaintext parameter flows correctly from SDKBuilder to KASClient. 
-""" +"""Test to verify that the use_plaintext parameter flows correctly from SDKBuilder to KASClient.""" from unittest.mock import MagicMock, patch @@ -9,7 +7,6 @@ def test_use_plaintext_flows_through_sdk_builder_to_kas_client(): """Test that use_plaintext parameter flows from SDKBuilder through to KASClient.""" - with patch("otdf_python.kas_client.KASClient") as mock_kas_client: # Mock the KASClient constructor to capture the arguments mock_kas_instance = MagicMock() @@ -40,7 +37,6 @@ def test_use_plaintext_flows_through_sdk_builder_to_kas_client(): def test_use_plaintext_false_flows_through_sdk_builder_to_kas_client(): """Test that use_plaintext=False flows from SDKBuilder through to KASClient.""" - with patch("otdf_python.kas_client.KASClient") as mock_kas_client: # Mock the KASClient constructor to capture the arguments mock_kas_instance = MagicMock() @@ -71,7 +67,6 @@ def test_use_plaintext_false_flows_through_sdk_builder_to_kas_client(): def test_use_plaintext_default_value(): """Test that the default use_plaintext value is False.""" - with patch("otdf_python.kas_client.KASClient") as mock_kas_client: # Mock the KASClient constructor to capture the arguments mock_kas_instance = MagicMock() diff --git a/tests/test_validate_otdf_python.py b/tests/test_validate_otdf_python.py index 1ac0818..9906f14 100644 --- a/tests/test_validate_otdf_python.py +++ b/tests/test_validate_otdf_python.py @@ -1,5 +1,6 @@ -""" -This file is effectively the same test coverage as: +"""Validation tests for OpenTDF Python SDK. + +This module provides the same test coverage as: https://github.com/b-long/opentdf-python-sdk/blob/v0.2.17/validate_otdf_python.py Execute using: @@ -35,7 +36,6 @@ def _get_sdk_and_tdf_config() -> tuple: def encrypt_file(input_path: Path) -> Path: """Encrypt a file and return the path to the encrypted file.""" - # Build the SDK sdk, tdf_config = _get_sdk_and_tdf_config() @@ -59,6 +59,7 @@ def decrypt_file(encrypted_path: Path) -> Path: def verify_encrypt_str() -> None: + """Verify string encryption functionality.""" print("Validating string encryption (local TDF)") try: sdk = get_sdk() @@ -97,6 +98,7 @@ def test_verify_encrypt_str(): def verify_encrypt_file() -> None: + """Verify file encryption functionality.""" print("Validating file encryption (local TDF)") try: with tempfile.TemporaryDirectory() as tmpDir: @@ -121,6 +123,7 @@ def test_verify_encrypt_file(): def verify_encrypt_decrypt_file() -> None: + """Verify encrypt/decrypt roundtrip functionality.""" print("Validating encrypt/decrypt roundtrip (local TDF)") try: with tempfile.TemporaryDirectory() as tmpDir: diff --git a/tests/test_version.py b/tests/test_version.py index c4740cf..eabbeec 100644 --- a/tests/test_version.py +++ b/tests/test_version.py @@ -1,10 +1,15 @@ +"""Tests for Version.""" + import unittest from otdf_python.version import Version class TestVersion(unittest.TestCase): + """Tests for Version class.""" + def test_parse_and_str(self): + """Test Version parsing and string representation.""" v = Version("1.2.3-alpha") self.assertEqual(v.major, 1) self.assertEqual(v.minor, 2) @@ -13,6 +18,7 @@ def test_parse_and_str(self): self.assertIn("Version{major=1, minor=2, patch=3", str(v)) def test_compare(self): + """Test Version comparison.""" v1 = Version("1.2.3") v2 = Version("1.2.4") v3 = Version("1.3.0") @@ -24,6 +30,7 @@ def test_compare(self): self.assertEqual(v1, Version(1, 2, 3)) def test_hash(self): + """Test Version hashing.""" v1 = Version("1.2.3") v2 = Version(1, 2, 3) self.assertEqual(hash(v1), hash(v2)) @@ 
-31,6 +38,7 @@ def test_hash(self): self.assertEqual(len(s), 1) def test_invalid(self): + """Test invalid Version string.""" with self.assertRaises(ValueError): Version("not.a.version") diff --git a/tests/test_zip_reader.py b/tests/test_zip_reader.py index cde645e..23517c2 100644 --- a/tests/test_zip_reader.py +++ b/tests/test_zip_reader.py @@ -1,3 +1,5 @@ +"""Tests for ZipReader.""" + import io import random import unittest @@ -8,7 +10,10 @@ class TestZipReader(unittest.TestCase): + """Tests for ZipReader class.""" + def test_read_and_namelist(self): + """Test reading zip and listing files.""" # Create a zip in memory writer = ZipWriter() writer.data("foo.txt", b"foo") @@ -25,6 +30,7 @@ def test_read_and_namelist(self): reader.close() def test_extract(self): + """Test extracting files from zip.""" import tempfile writer = ZipWriter() @@ -39,6 +45,7 @@ def test_extract(self): reader.close() def test_entry_interface_and_random_files(self): + """Test zip entry interface with random files.""" # Create a zip with many random files r = random.Random(42) num_entries = r.randint(10, 20) # Use a smaller number for test speed diff --git a/tests/test_zip_writer.py b/tests/test_zip_writer.py index bf7dccd..f663735 100644 --- a/tests/test_zip_writer.py +++ b/tests/test_zip_writer.py @@ -1,3 +1,5 @@ +"""Tests for ZipWriter.""" + import io import unittest import zipfile @@ -6,7 +8,10 @@ class TestZipWriter(unittest.TestCase): + """Tests for ZipWriter class.""" + def test_data_and_stream(self): + """Test writing data and streams to zip.""" out = io.BytesIO() writer = ZipWriter(out) # Write using data @@ -23,6 +28,7 @@ def test_data_and_stream(self): self.assertEqual(z.read("bar.txt"), b"bar contents") def test_getvalue(self): + """Test getting writer value as bytes.""" writer = ZipWriter() writer.data("a.txt", b"A") writer.finish()