feat: aggregations

This commit is contained in:
yim7
2022-09-22 13:24:42 +08:00
parent 70a6e2bd78
commit 8a9c3974fc
4 changed files with 156 additions and 23 deletions

View File

@@ -1,3 +1,4 @@
from .session import SearchSession
from .model import BaseModel
from .query import *
from .aggs import *

65
opensearchorm/aggs.py Normal file
View File

@@ -0,0 +1,65 @@
import abc
from typing import Optional
class Aggregation(abc.ABC):
    """Base class for OpenSearch aggregation clauses.

    Text fields are addressed through their ``.keyword`` sub-field so the
    aggregation runs on exact values rather than analyzed tokens.
    """

    def __init__(self, field: str, is_text: bool = False) -> None:
        # Analyzed text fields must be aggregated on the keyword sub-field.
        suffix = '.keyword' if is_text else ''
        self.field = field + suffix

    @abc.abstractmethod
    def compile(self, depth: int = 1):
        """Return the aggregation request body, keyed by *depth*."""
class MetricAggregation(Aggregation):
    """Marker base for single-value (leaf) aggregations such as sum/cardinality."""
class BucketAggregation(Aggregation):
    """Base for aggregations that group documents into buckets."""

    @abc.abstractmethod
    def nested(self, child: Aggregation):
        """Attach *child* as a sub-aggregation; implementations return self."""
class Terms(BucketAggregation):
    """Bucket documents by the distinct values of a field."""

    def __init__(self, field: str, is_text: bool = False, max_buckets: int = 100) -> None:
        super().__init__(field, is_text)
        # Upper bound on returned buckets ("size" in the OpenSearch DSL).
        self.max_buckets = max_buckets
        # Optional sub-aggregation applied inside every bucket.
        self.child: Optional[Aggregation] = None

    def compile(self, depth: int = 1):
        """Build the terms clause; a nested child is compiled at depth + 1."""
        if self.child:
            sub_aggs = self.child.compile(depth + 1)
        else:
            sub_aggs = {}
        return {
            depth: {
                'terms': {
                    'field': self.field,
                    'size': self.max_buckets,
                },
                'aggs': sub_aggs,
            }
        }

    def nested(self, child: Aggregation):
        """Attach *child* as a sub-aggregation and return self for chaining."""
        self.child = child
        return self
class Cardinality(MetricAggregation):
    """Approximate count of distinct values of a field."""

    def compile(self, depth: int = 1):
        """Return a single-level cardinality clause keyed by *depth*."""
        clause = {
            'cardinality': {
                'field': self.field,
            }
        }
        return {depth: clause}
class Sum(MetricAggregation):
    """Sum of a numeric field over matching documents."""

    def compile(self, depth: int = 1):
        """Return a single-level sum clause keyed by *depth*."""
        clause = {
            'sum': {
                'field': self.field,
            }
        }
        return {depth: clause}

View File

@@ -119,6 +119,7 @@ class ModelQuery(Expr):
'must_not': [e.compile() for e in self.__exclude],
'should': [e.compile() for e in self.__union],
'filter': [e.compile() for e in self.__filter],
'minimum_should_match': 1 if self.__union else 0,
}
}

View File

@@ -1,10 +1,12 @@
import json
import logging
from typing import List, Optional, Type, TypeVar
from typing import List, Optional, Type, TypeVar, cast
from opensearchpy import OpenSearch
from opensearchorm.model import BaseModel
from opensearchorm.query import ModelQuery, Expr
from opensearchorm.aggs import *
Model = TypeVar('Model', bound=BaseModel)
@@ -36,15 +38,18 @@ class SearchSession:
def search(self, **kwargs):
    """Forward a raw search request to the wrapped OpenSearch client."""
    client = self.client
    return client.search(**kwargs)
def count(self, **kwargs):
    """Forward a raw count request to the wrapped OpenSearch client."""
    client = self.client
    return client.count(**kwargs)
class QueryExecutor:
    """Fluent builder that compiles a ModelQuery and runs it via a session."""

    def __init__(self, model_cls: Type[Model], session: SearchSession):
        # Query tree being built up by filter()/exclude()/etc.
        self.__query = ModelQuery(model_cls)
        self.__model_cls = model_cls
        # _source fields to request; empty means "use the model's defaults".
        self.__include_fields = []
        # Pagination: None lets the server apply its own defaults.
        self.__limit: Optional[int] = None
        self.__offset: Optional[int] = None
        self.__session = session
def filter(self, *args: Expr, **kwargs):
self.__query.filter(*args, **kwargs)
@@ -59,43 +64,104 @@ class QueryExecutor:
return self
def limit(self, limit: int):
    """Cap the number of hits returned by fetch(); returns self for chaining."""
    self.__limit = limit
    return self
def offset(self, offset: int):
    """Set the starting hit index (``from_``) for fetch(); returns self."""
    self.__offset = offset
    return self
def values(self, fields: List[str]):
    """Restrict the returned ``_source`` fields; returns self for chaining.

    When set, fetch() yields plain dicts instead of parsed model objects.
    """
    self.__include_fields = fields
    return self
def fetch(self, **kwargs):
    """Execute the compiled query and return the matching documents.

    Returns raw ``_source`` dicts when an explicit field list was set via
    values(); otherwise each hit is parsed into the model class.
    Extra ``**kwargs`` are passed through to the client search call.
    """
    body = {
        'query': self.__query.compile(),
    }
    # Long request timeout: large result pages can be slow to assemble.
    params = {
        'format': 'json',
        'request_timeout': 300,
    }
    logging.debug('query:\n%s', json.dumps(body))
    model = self.__model_cls
    assert model and model.__index__, 'model has no index'
    resp = self.__session.search(
        body=body,
        params=params,
        index=model.__index__,
        size=self.__limit,
        from_=self.__offset,
        # Fall back to the model's default field set when none was chosen.
        _source_includes=self.__include_fields or model.default_fields(),
        **kwargs,
    )
    hits = resp['hits']['hits']
    logging.debug('raw result: %s', hits)
    if self.__include_fields:
        # Caller asked for specific fields: hand back plain dicts.
        return [hit['_source'] for hit in hits]
    else:
        return [model.parse_obj(hit['_source']) for hit in hits]
def scroll(self, **kwargs):
    """Placeholder for scroll/cursor pagination; not implemented yet (returns None)."""
def unique_count(self, field: str, is_text: bool = False, **kwargs) -> int:
    """Count the distinct values of *field* via a cardinality aggregation."""
    result = self.aggregate(Cardinality(field, is_text), **kwargs)
    # cast() is a runtime no-op; it only informs the type checker.
    return cast(int, result)
def sum(self, field: str, **kwargs) -> float:
    """Sum the numeric *field* over all documents matching the query."""
    result = self.aggregate(Sum(field), **kwargs)
    # Fixed: was cast(int, ...), contradicting the declared float return type.
    # cast() is a runtime no-op, so this only corrects the static typing.
    return cast(float, result)
def count(self, **kwargs) -> int:
    """Return the number of documents matching the current query."""
    model = self.__model_cls
    assert model and model.__index__, 'model has no index'
    body = {
        'query': self.__query.compile(),
    }
    logging.debug('query:\n%s', json.dumps(body))
    result = self.__session.count(
        body=body,
        index=model.__index__,
        **kwargs,
    )
    return result['count']
def aggregate(self, aggs: Aggregation, **kwargs):
    """Run *aggs* on top of the current query and return the parsed result."""
    request = {
        'query': self.__query.compile(),
        'aggs': aggs.compile(depth=1),
    }
    logging.debug('query:\n%s', json.dumps(request))
    model = self.__model_cls
    assert model and model.__index__, 'model has no index'
    # size=0: only aggregation results are wanted, no document hits.
    resp = self.__session.search(
        body=request,
        index=model.__index__,
        size=0,
        **kwargs,
    )
    return parse_aggregations(resp['aggregations'], depth=1)
def parse_aggregations(data: dict, depth: int = 1):
    """Collapse the numbered aggregation response levels into plain values.

    Aggregation levels are keyed by their depth as a string (the request
    uses integer keys, which JSON serialization turns into strings).

    Returns:
        - for bucket aggregations: a dict mapping bucket key to the nested
          result, or to the bucket's doc_count when there is no child,
        - for single-value aggregations: the bare metric value,
        - None when the expected depth key is absent.
    """
    level = data.get(str(depth))
    if level is None:
        return None
    if 'buckets' in level:
        result = {}
        for bucket in level['buckets']:
            key = bucket['key']
            count = bucket['doc_count']
            children = parse_aggregations(bucket, depth + 1)
            # Fixed: the previous truthiness test (`children if children else
            # count`) silently replaced a falsy nested metric value (e.g. a
            # nested sum/cardinality of 0) with the bucket's doc_count.
            # Missing (None) or empty ({}) child results still fall back to
            # the doc_count, preserving the original behavior for those.
            result[key] = count if children in (None, {}) else children
        return result
    return level['value']