183 changes: 146 additions & 37 deletions sagemaker-core/src/sagemaker/core/jumpstart/search.py
@@ -7,6 +7,78 @@
logger = logging.getLogger(__name__)


class _ExpressionNode:
"""Base class for expression AST nodes."""

def evaluate(self, keywords: List[str]) -> bool:
"""Evaluate this node against the given keywords."""
raise NotImplementedError


class _AndNode(_ExpressionNode):
"""AND logical operator node."""

def __init__(self, left: _ExpressionNode, right: _ExpressionNode):
self.left = left
self.right = right

def evaluate(self, keywords: List[str]) -> bool:
return self.left.evaluate(keywords) and self.right.evaluate(keywords)


class _OrNode(_ExpressionNode):
"""OR logical operator node."""

def __init__(self, left: _ExpressionNode, right: _ExpressionNode):
self.left = left
self.right = right

def evaluate(self, keywords: List[str]) -> bool:
return self.left.evaluate(keywords) or self.right.evaluate(keywords)


class _NotNode(_ExpressionNode):
"""NOT logical operator node."""

def __init__(self, operand: _ExpressionNode):
self.operand = operand

def evaluate(self, keywords: List[str]) -> bool:
return not self.operand.evaluate(keywords)
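
# Example (illustrative): _AndNode(_PatternNode("a"), _NotNode(_PatternNode("b")))
# is the AST for "a AND NOT b"; it evaluates to True for keywords ["a"] and to
# False for ["a", "b"].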


class _PatternNode(_ExpressionNode):
"""Pattern matching node for keywords with wildcard support."""

def __init__(self, pattern: str):
self.pattern = pattern.strip('"').strip("'")

def evaluate(self, keywords: List[str]) -> bool:
"""Check if any keyword matches this pattern."""
for keyword in keywords:
if self._matches_pattern(keyword, self.pattern):
return True
return False

def _matches_pattern(self, keyword: str, pattern: str) -> bool:
"""Check if a keyword matches a pattern with wildcard support."""
if pattern.startswith("*") and pattern.endswith("*"):
# Contains pattern: *text*
stripped = pattern.strip("*")
return stripped in keyword
elif pattern.startswith("*"):
# Ends with pattern: *text
stripped = pattern[1:]
return keyword.endswith(stripped)
elif pattern.endswith("*"):
# Starts with pattern: text*
stripped = pattern[:-1]
return keyword.startswith(stripped)
else:
# Exact match
return keyword == pattern
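
# Example (illustrative): _PatternNode("*gen*") matches "text-generation"
# (contains), _PatternNode("text*") matches "text-generation" (prefix),
# _PatternNode("*base") matches "roberta-base" (suffix), and
# _PatternNode("bert") matches only the exact keyword "bert".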


class _Filter:
"""
A filter that evaluates logical expressions against a list of keyword strings.
@@ -28,6 +100,7 @@ def __init__(self, expression: str) -> None:
Supports AND, OR, NOT, parentheses, and wildcard patterns (*).
"""
self.expression: str = expression
self._ast: Optional[_ExpressionNode] = None

def match(self, keywords: List[str]) -> bool:
"""
@@ -39,54 +112,90 @@ def match(self, keywords: List[str]) -> bool:
Returns:
bool: True if the expression evaluates to True for the given keywords, else False.
"""
        try:
            # Parse lazily and cache the AST so repeated match() calls reuse it.
            if self._ast is None:
                self._ast = self._parse_expression(self.expression)
            return self._ast.evaluate(keywords)
        except Exception:
            # Malformed expressions fail closed rather than raising.
            return False
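
    # Example (illustrative): _Filter("text* AND NOT tensorflow").match(
    #     ["pytorch", "text-generation"]) returns True, since "text*" matches
    # "text-generation" and no keyword equals "tensorflow".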

    def _parse_expression(self, expr: str) -> _ExpressionNode:
        """
        Parse the logical filter expression into an AST.

        Args:
            expr (str): The raw expression to parse.

        Returns:
            _ExpressionNode: Root node of the parsed expression AST.
        """
        tokens = self._tokenize(expr)
        result, _ = self._parse_or_expression(tokens, 0)
        return result

def _tokenize(self, expr: str) -> List[str]:
"""Tokenize the expression into logical operators, keywords, and parentheses."""
return re.findall(r"\bAND\b|\bOR\b|\bNOT\b|[^\s()]+|\(|\)", expr, flags=re.IGNORECASE)
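
    # Example (illustrative): _tokenize('bert* OR (nlp AND NOT legacy)') yields
    # ['bert*', 'OR', '(', 'nlp', 'AND', 'NOT', 'legacy', ')']; operator words
    # are matched case-insensitively, so 'or' tokenizes the same way as 'OR'.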

def _parse_or_expression(self, tokens: List[str], pos: int) -> tuple[_ExpressionNode, int]:
"""Parse OR expression (lowest precedence)."""
left, pos = self._parse_and_expression(tokens, pos)

while pos < len(tokens) and tokens[pos].upper() == "OR":
pos += 1 # Skip OR token
right, pos = self._parse_and_expression(tokens, pos)
left = _OrNode(left, right)

return left, pos

def _parse_and_expression(self, tokens: List[str], pos: int) -> tuple[_ExpressionNode, int]:
"""Parse AND expression (medium precedence)."""
left, pos = self._parse_not_expression(tokens, pos)

while pos < len(tokens) and tokens[pos].upper() == "AND":
pos += 1 # Skip AND token
right, pos = self._parse_not_expression(tokens, pos)
left = _AndNode(left, right)

return left, pos
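
    # Example (illustrative): because OR binds more loosely than AND, the
    # expression "a OR b AND c" parses as _OrNode(a, _AndNode(b, c)), i.e.
    # "a OR (b AND c)", matching conventional boolean precedence.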

def _parse_not_expression(self, tokens: List[str], pos: int) -> tuple[_ExpressionNode, int]:
"""Parse NOT expression (highest precedence)."""
if pos < len(tokens) and tokens[pos].upper() == "NOT":
pos += 1 # Skip NOT token
operand, pos = self._parse_primary_expression(tokens, pos)
return _NotNode(operand), pos
else:
return self._parse_primary_expression(tokens, pos)

def _parse_primary_expression(self, tokens: List[str], pos: int) -> tuple[_ExpressionNode, int]:
"""Parse primary expression (parentheses or pattern)."""
if pos >= len(tokens):
raise ValueError("Unexpected end of expression")

token = tokens[pos]

if token == "(":
pos += 1 # Skip opening parenthesis
expr, pos = self._parse_or_expression(tokens, pos)
if pos >= len(tokens) or tokens[pos] != ")":
raise ValueError("Missing closing parenthesis")
pos += 1 # Skip closing parenthesis
return expr, pos
elif token == ")":
raise ValueError("Unexpected closing parenthesis")
else:
# Pattern token
return _PatternNode(token), pos + 1
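
    # Example (illustrative): for "llama* AND NOT (vision OR audio)", the
    # tokens are ['llama*', 'AND', 'NOT', '(', 'vision', 'OR', 'audio', ')']
    # and the resulting AST is:
    #   _AndNode(_PatternNode('llama*'),
    #            _NotNode(_OrNode(_PatternNode('vision'), _PatternNode('audio'))))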

if pattern.startswith("*") and pattern.endswith("*"):
return f"{repr(stripped)} in k"
elif pattern.startswith("*"):
return f"k.endswith({repr(stripped)})"
elif pattern.endswith("*"):
return f"k.startswith({repr(stripped)})"
else:
return f"k == {repr(pattern)}"

def convert_token(token: str) -> str:
upper = token.upper()
if upper == "AND":
return "and"
elif upper == "OR":
return "or"
elif upper == "NOT":
return "not"
elif token in ("(", ")"):
return token
else:
return f"any({wildcard_condition(token)} for k in keywords)"

converted_tokens = [convert_token(tok) for tok in tokens]
return " ".join(converted_tokens)
def _convert_expression(self, expr: str) -> str:
"""
Legacy method for backward compatibility.
This method is no longer used but kept to avoid breaking changes.
"""
# This method is deprecated and should not be used
# It's kept only for backward compatibility
return expr


def _list_all_hub_models(hub_name: str, sm_client: Session) -> Iterator[HubContent]: