feat: query builder

This commit is contained in:
yim7
2022-09-21 22:05:04 +08:00
parent 438aec7699
commit 70a6e2bd78
8 changed files with 825 additions and 0 deletions

View File

@@ -0,0 +1,3 @@
from .session import SearchSession
from .model import BaseModel
from .query import *

10
opensearchorm/model.py Normal file
View File

@@ -0,0 +1,10 @@
from typing import ClassVar, Optional
from pydantic import BaseModel as RawBaseModel
class BaseModel(RawBaseModel):
__index__: ClassVar[Optional[str]] = None
@classmethod
def default_fields(cls):
return list(cls.__fields__.keys())

174
opensearchorm/query.py Normal file
View File

@@ -0,0 +1,174 @@
import abc
from enum import Enum
import logging
from typing import List, Optional, Type, TypeVar, Union
from opensearchorm.model import BaseModel
Model = TypeVar('Model', bound=BaseModel)
class Expr(abc.ABC):
@abc.abstractmethod
def compile(self) -> dict:
...
class Contains(Expr):
def __init__(self, field: str, values: list):
self.field = field
self.values = values
def compile(self):
return {
'bool': {
'should': [MatchPhrase(self.field, v).compile() for v in self.values],
'minimum_should_match': 1,
}
}
class Range(Expr):
def __init__(self, field: str, value: Union[str, int], operator: str):
self.field = field
self.op = operator
self.value = value
def compile(self):
return {
'range': {
self.field: {
self.op: self.value,
}
}
}
class MatchPhrase(Expr):
def __init__(self, field: str, value: str):
self.field = field
self.value = value
def compile(self):
return {
'match_phrase': {
self.field: self.value,
}
}
class Prefix(Expr):
def __init__(self, field: str, value: str):
self.field = field
self.value = value
def compile(self):
return {
'prefix': {
self.field: self.value,
}
}
class Wildcard(Expr):
def __init__(self, field: str, value: str):
self.field = field
self.value = value
def compile(self):
return {
'wildcard': {
self.field: self.value,
}
}
class RegExp(Expr):
def __init__(self, field: str, value: str):
self.field = field
self.value = value
def compile(self):
return {
'regexp': {
self.field: self.value,
}
}
class Operator(Enum):
PREFIX = '__prefix'
REGEXP = '__regexp'
CONTAINS = '__contains'
GTE = '__gte'
GT = '__gt'
LTE = '__lte'
LT = '__lt'
class ModelQuery(Expr):
def __init__(self, model_cls: Type[Model]):
self.__model_cls = model_cls
self.__filter: List[Expr] = []
self.__exclude: List[Expr] = []
self.__union: List[Expr] = []
def compile(self):
return {
'bool': {
'must_not': [e.compile() for e in self.__exclude],
'should': [e.compile() for e in self.__union],
'filter': [e.compile() for e in self.__filter],
}
}
def parse_clause(self, raw_field: str, value) -> Expr:
field = raw_field
for op in Operator:
suffix: str = op.value
if raw_field.endswith(suffix):
field = raw_field.removesuffix(suffix)
logging.debug('parse field: %s, raw: %s', field, raw_field)
if op == Operator.CONTAINS:
return Contains(field, value)
elif op == Operator.PREFIX:
return Prefix(field, value)
elif op == Operator.REGEXP:
return RegExp(field, value)
elif op in (Operator.GTE, Operator.GT, Operator.LTE, Operator.LT):
op = suffix.lstrip('_')
return Range(field, value, op)
model = self.__model_cls
valid_fields = set(model.default_fields())
assert field in valid_fields, f'check field name: {field}'
return MatchPhrase(field, value)
def parse_clauses(self, **kwargs):
clauses = []
for k, v in kwargs.items():
cond = self.parse_clause(k, v)
clauses.append(cond)
return clauses
def filter(self, *args: Expr, **kwargs):
conditions = self.parse_clauses(**kwargs)
self.__filter.extend(args)
self.__filter.extend(conditions)
return self
def union(self, *args: Expr, **kwargs):
conditions = self.parse_clauses(**kwargs)
self.__union.extend(args)
self.__union.extend(conditions)
return self
def exclude(self, *args: Expr, **kwargs):
conditions = self.parse_clauses(**kwargs)
self.__exclude.extend(args)
self.__exclude.extend(conditions)
return self

101
opensearchorm/session.py Normal file
View File

@@ -0,0 +1,101 @@
import logging
from typing import List, Optional, Type, TypeVar
from opensearchpy import OpenSearch
from opensearchorm.model import BaseModel
from opensearchorm.query import ModelQuery, Expr
Model = TypeVar('Model', bound=BaseModel)
class SearchSession:
def __init__(self, host: str, user: str, password: str, **kwargs) -> None:
self.client = OpenSearch(
hosts=[
host,
],
http_auth=(user, password),
http_compress=True,
use_ssl=True,
verify_certs=True,
ssl_assert_hostname=False,
ssl_show_warn=False,
**kwargs,
)
def __enter__(self):
return self
def __exit__(self, type, value, traceback):
self.client.close()
def select(self, model: Type[Model]):
return QueryExecutor(model, self)
def search(self, **kwargs):
return self.client.search(**kwargs)
class QueryExecutor:
def __init__(self, model_cls: Type[Model], session):
self.__query = ModelQuery(model_cls)
self.__model_cls = model_cls
self._include_fields = []
self._limit: Optional[int] = None
self._offset: Optional[int] = None
self._session = session
def filter(self, *args: Expr, **kwargs):
self.__query.filter(*args, **kwargs)
return self
def union(self, *args: Expr, **kwargs):
self.__query.union(*args, **kwargs)
return self
def exclude(self, *args: Expr, **kwargs):
self.__query.exclude(*args, **kwargs)
return self
def limit(self, limit: int):
self._limit = limit
return self
def offset(self, offset: int):
self._offset = offset
return self
def values(self, fields: List[str]):
self._include_fields = fields
return self
def fetch(self):
body = {
'query': self.__query.compile(),
}
logging.debug('query:\n%s', body)
params = {
'format': 'json',
'request_timeout': 300,
}
model = self.__model_cls
assert model and model.__index__, 'model has no index'
data = self._session.search(
body=body,
params=params,
index=model.__index__,
size=self._limit,
from_=self._offset,
_source_includes=self._include_fields or model.default_fields(),
)
hits = data['hits']['hits']
logging.debug('raw result: %s', hits)
if self._include_fields:
return [hit['_source'] for hit in hits]
else:
return [model.parse_obj(hit['_source']) for hit in hits]