feat: refactor range and prefix query

This commit is contained in:
yim7
2022-09-22 22:25:23 +08:00
parent 34f6252213
commit 50462af958
3 changed files with 44 additions and 24 deletions

View File

@@ -1,7 +1,8 @@
import abc
from datetime import date, datetime
from enum import Enum
import logging
from typing import List, Type, TypeVar, Union
from typing import Any, Callable, Dict, List, Tuple, Type, TypeVar, Union
from opensearchorm.model import BaseModel
@@ -15,31 +16,45 @@ class Expr(abc.ABC):
class Contains(Expr):
    """Match a field against several candidate values at once.

    Compiles to a ``bool`` query holding one ``MatchPhrase`` clause per value
    under ``should``.

    Args:
        field: Name of the document field to search.
        values: Candidate values; each one becomes its own ``should`` clause.
        min_match: Emitted verbatim as ``minimum_should_match``. Defaults to 1.
    """

    def __init__(self, field: str, values: list, min_match: int = 1):
        self.field = field
        self.values = values
        self.min_match = min_match

    def compile(self) -> dict:
        """Return the query body as a plain dict."""
        should = [MatchPhrase(self.field, v).compile() for v in self.values]
        return {
            'bool': {
                'should': should,
                'minimum_should_match': self.min_match,
            }
        }
class Range(Expr):
    """Range query over one field, described as an interval.

    Args:
        field: Name of the document field to constrain.
        interval: ``(lower, upper)`` pair. A ``None`` entry leaves that side
            unbounded, i.e. no operator is emitted for it.
        left_open: When true the lower bound is exclusive (``gt``); otherwise
            it is inclusive (``gte``).
        right_open: When true the upper bound is exclusive (``lt``); otherwise
            it is inclusive (``lte``).
    """

    def __init__(self, field: str, interval: Tuple[Any, Any], *, left_open: bool = False, right_open: bool = False):
        self.field = field
        self.interval = interval
        self.left_open = left_open
        self.right_open = right_open

    def compile(self) -> dict:
        """Return the ``range`` query body as a plain dict."""
        lower, upper = self.interval

        # date/datetime bounds are emitted as ISO 8601 strings.
        if isinstance(lower, (date, datetime)):
            lower = lower.isoformat()
        if isinstance(upper, (date, datetime)):
            upper = upper.isoformat()

        # Named `bounds` rather than `range` so the builtin is not shadowed.
        bounds = {}
        if lower is not None:
            bounds['gt' if self.left_open else 'gte'] = lower
        if upper is not None:
            bounds['lt' if self.right_open else 'lte'] = upper

        return {
            'range': {
                self.field: bounds,
            }
        }
@@ -57,14 +72,14 @@ class MatchPhrase(Expr):
}
class MatchPhrasePrefix(Expr):
    """``match_phrase_prefix`` query on a single field.

    Args:
        field: Name of the document field to search.
        value: Text placed under the field in the query body.
    """

    def __init__(self, field: str, value: str):
        self.field = field
        self.value = value

    def compile(self) -> dict:
        """Return the ``match_phrase_prefix`` query body as a plain dict."""
        return {
            'match_phrase_prefix': {
                self.field: self.value,
            }
        }
@@ -106,6 +121,17 @@ class Operator(Enum):
LT = '__lt'
# Dispatch table: each suffix operator maps to a factory that takes the bare
# field name and the filter value and returns the expression to compile.
OPERATOR_FUNCTIONS: Dict[Operator, Callable[[str, Any], Expr]] = {
    # Text operators.
    Operator.CONTAINS: lambda field, value: Contains(field, values=value),
    Operator.PREFIX: lambda field, value: MatchPhrasePrefix(field, value=value),
    Operator.REGEXP: lambda field, value: RegExp(field, value),
    # Comparisons are one-sided intervals; the strict variants open the
    # side that carries the bound.
    Operator.GTE: lambda field, value: Range(field, interval=(value, None)),
    Operator.GT: lambda field, value: Range(field, interval=(value, None), left_open=True),
    Operator.LTE: lambda field, value: Range(field, interval=(None, value)),
    Operator.LT: lambda field, value: Range(field, interval=(None, value), right_open=True),
}
class ModelQuery(Expr):
def __init__(self, model_cls: Type[Model]):
self.__model_cls = model_cls
@@ -137,18 +163,9 @@ class ModelQuery(Expr):
suffix: str = op.value
if raw_field.endswith(suffix):
field, _ = raw_field.rsplit(suffix)
self.check_valid_field(field)
logging.debug('parse field: %s, raw: %s', field, raw_field)
if op == Operator.CONTAINS:
return Contains(field, value)
if op == Operator.PREFIX:
return Prefix(field, value)
if op == Operator.REGEXP:
return RegExp(field, value)
if op in (Operator.GTE, Operator.GT, Operator.LTE, Operator.LT):
op = suffix.lstrip('_')
return Range(field, value, op)
self.check_valid_field(field)
return OPERATOR_FUNCTIONS[op](field, value)
self.check_valid_field(field)
return MatchPhrase(field, value)

View File

@@ -6,7 +6,7 @@ from opensearchpy import OpenSearch
from opensearchorm.model import BaseModel
from opensearchorm.query import ModelQuery, Expr
from opensearchorm.aggs import Aggregation, Sum, Cardinality
from opensearchorm.aggs import Aggregation, Sum, Cardinality, Terms
from opensearchorm.utils import parse_aggregations
Host = Union[str, dict]
@@ -163,3 +163,6 @@ class QueryExecutor:
**kwargs,
)
return resp['count']
def group_by(self, field: str, max_buckets: int = 100):
    """Run ``self.aggregate`` with a ``Terms`` aggregation.

    Args:
        field: Passed to ``Terms`` as its first argument.
        max_buckets: Passed to ``Terms`` as its second argument; defaults
            to 100. Presumably an upper limit on returned buckets — confirm
            against ``Terms``.

    Returns:
        Whatever ``self.aggregate`` returns for that aggregation.
    """
    terms = Terms(field, max_buckets)
    return self.aggregate(terms)

View File

@@ -1,6 +1,6 @@
[tool.poetry]
name = "opensearch-orm"
version = "0.1.3"
version = "0.1.4"
description = ""
authors = ["yim7 <yimchiu7@gmail.com>"]
readme = "README.md"