mirror of
https://github.com/d0zingcat/opensearch-orm.git
synced 2026-05-14 15:10:12 +00:00
feat: aggregations
This commit is contained in:
@@ -1,3 +1,4 @@
|
||||
from .session import SearchSession
|
||||
from .model import BaseModel
|
||||
from .query import *
|
||||
from .aggs import *
|
||||
|
||||
65
opensearchorm/aggs.py
Normal file
65
opensearchorm/aggs.py
Normal file
@@ -0,0 +1,65 @@
|
||||
import abc
from typing import Optional

# Explicit public API: this module is star-imported (``from .aggs import *``).
# Without ``__all__``, a star import would also leak the ``abc`` module and
# ``Optional`` into the importer's namespace.
__all__ = [
    'Aggregation',
    'MetricAggregation',
    'BucketAggregation',
    'Terms',
    'Cardinality',
    'Sum',
]


class Aggregation(abc.ABC):
    """Base class for OpenSearch aggregation builders.

    Each aggregation compiles to the ``aggs`` fragment of a search body.
    The aggregation is named by its ``depth`` (1 for the outermost level),
    which lets the response parser walk nested results level by level.
    """

    def __init__(self, field: str, is_text: bool = False) -> None:
        # Text fields are analyzed and not directly aggregatable; aggregate
        # on the ``.keyword`` sub-field instead.
        self.field = f'{field}.keyword' if is_text else field

    @abc.abstractmethod
    def compile(self, depth: int = 1):
        """Return the aggregation request dict, keyed by *depth*."""
        ...


class MetricAggregation(Aggregation):
    """Marker base for single-value metric aggregations (sum, cardinality, ...)."""
    ...


class BucketAggregation(Aggregation):
    """Base for bucketing aggregations, which may nest a child aggregation."""

    @abc.abstractmethod
    def nested(self, child: Aggregation):
        """Attach *child* to run inside each bucket; returns self for chaining."""
        ...


class Terms(BucketAggregation):
    """A ``terms`` bucket aggregation with optional nested child aggregation.

    ``max_buckets`` maps to the ``size`` parameter of the terms aggregation.
    """

    def __init__(self, field: str, is_text: bool = False, max_buckets: int = 100) -> None:
        super().__init__(field, is_text)
        self.max_buckets = max_buckets
        self.child: Optional[Aggregation] = None

    def compile(self, depth: int = 1):
        return {
            depth: {
                'terms': {
                    'field': self.field,
                    'size': self.max_buckets,
                },
                # The child is named depth + 1 so the response parser can
                # locate it inside each bucket.
                'aggs': self.child.compile(depth + 1) if self.child else {},
            }
        }

    def nested(self, child: Aggregation):
        self.child = child
        return self


class Cardinality(MetricAggregation):
    """Approximate distinct-value count of a field."""

    def compile(self, depth: int = 1):
        return {
            depth: {
                'cardinality': {
                    'field': self.field,
                }
            }
        }


class Sum(MetricAggregation):
    """Sum of a numeric field."""

    def compile(self, depth: int = 1):
        return {
            depth: {
                'sum': {
                    'field': self.field,
                }
            }
        }
|
||||
@@ -119,6 +119,7 @@ class ModelQuery(Expr):
|
||||
'must_not': [e.compile() for e in self.__exclude],
|
||||
'should': [e.compile() for e in self.__union],
|
||||
'filter': [e.compile() for e in self.__filter],
|
||||
'minimum_should_match': 1 if self.__union else 0,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -1,10 +1,12 @@
|
||||
import json
|
||||
import logging
|
||||
from typing import List, Optional, Type, TypeVar
|
||||
from typing import List, Optional, Type, TypeVar, cast
|
||||
|
||||
from opensearchpy import OpenSearch
|
||||
|
||||
from opensearchorm.model import BaseModel
|
||||
from opensearchorm.query import ModelQuery, Expr
|
||||
from opensearchorm.aggs import *
|
||||
|
||||
Model = TypeVar('Model', bound=BaseModel)
|
||||
|
||||
@@ -36,15 +38,18 @@ class SearchSession:
|
||||
def search(self, **kwargs):
    """Forward *kwargs* verbatim to the underlying client's ``search`` API."""
    return self.client.search(**kwargs)
|
||||
|
||||
def count(self, **kwargs):
    """Forward *kwargs* verbatim to the underlying client's ``count`` API."""
    return self.client.count(**kwargs)
|
||||
|
||||
|
||||
class QueryExecutor:
|
||||
def __init__(self, model_cls: Type[Model], session: SearchSession):
    """Bind a query builder for *model_cls* to an open search *session*."""
    self.__model_cls = model_cls
    self.__session = session
    self.__query = ModelQuery(model_cls)
    # Pagination / projection state, applied lazily at fetch() time.
    self.__include_fields = []
    self.__limit: Optional[int] = None
    self.__offset: Optional[int] = None
|
||||
|
||||
def filter(self, *args: Expr, **kwargs):
|
||||
self.__query.filter(*args, **kwargs)
|
||||
@@ -59,43 +64,104 @@ class QueryExecutor:
|
||||
return self
|
||||
|
||||
def limit(self, limit: int):
    """Cap the number of hits returned by fetch(); chainable."""
    self.__limit = limit
    return self
|
||||
|
||||
def offset(self, offset: int):
    """Skip the first *offset* hits (maps to ``from_``); chainable."""
    self.__offset = offset
    return self
|
||||
|
||||
def values(self, fields: List[str]):
    """Project only *fields*; fetch() then returns raw dicts, not models. Chainable."""
    self.__include_fields = fields
    return self
|
||||
|
||||
def fetch(self, **kwargs):
    """Execute the compiled query and return the matching documents.

    Returns raw ``_source`` dicts when values() selected explicit fields,
    otherwise instances parsed via the model's ``parse_obj``.
    Extra *kwargs* are passed through to the session's search call.
    """
    body = {
        'query': self.__query.compile(),
    }
    params = {
        'format': 'json',
        'request_timeout': 300,
    }
    logging.debug('query:\n%s', json.dumps(body))

    model = self.__model_cls
    assert model and model.__index__, 'model has no index'

    resp = self.__session.search(
        body=body,
        params=params,
        index=model.__index__,
        size=self.__limit,
        from_=self.__offset,
        _source_includes=self.__include_fields or model.default_fields(),
        **kwargs,
    )

    hits = resp['hits']['hits']
    logging.debug('raw result: %s', hits)
    if self.__include_fields:
        # Caller asked for specific fields: hand back plain dicts.
        return [hit['_source'] for hit in hits]
    # Full documents: validate and parse into model instances.
    return [model.parse_obj(hit['_source']) for hit in hits]
|
||||
|
||||
def scroll(self, **kwargs):
    """Scroll-API pagination — placeholder, not implemented yet."""
    ...
|
||||
|
||||
def unique_count(self, field: str, is_text: bool = False, **kwargs) -> int:
    """Approximate number of distinct values of *field*, via a cardinality aggregation."""
    return cast(int, self.aggregate(Cardinality(field, is_text), **kwargs))
|
||||
|
||||
def sum(self, field: str, **kwargs) -> float:
    """Sum the numeric *field* over all documents matching the query.

    The aggregation result is a float; cast accordingly (``cast`` is a
    runtime no-op — the previous ``cast(int, ...)`` contradicted the
    declared ``-> float`` return type).
    """
    resp = self.aggregate(Sum(field), **kwargs)
    return cast(float, resp)
|
||||
|
||||
def count(self, **kwargs) -> int:
    """Return the number of documents matching the compiled query."""
    body = {'query': self.__query.compile()}
    logging.debug('query:\n%s', json.dumps(body))

    model = self.__model_cls
    assert model and model.__index__, 'model has no index'

    resp = self.__session.count(
        body=body,
        index=model.__index__,
        **kwargs,
    )
    return resp['count']
|
||||
|
||||
def aggregate(self, aggs: Aggregation, **kwargs):
    """Run *aggs* against the compiled query and return the parsed result.

    The search is executed with ``size=0`` because only the aggregations
    section of the response is consumed.
    """
    body = {
        'query': self.__query.compile(),
        'aggs': aggs.compile(depth=1),
    }
    logging.debug('query:\n%s', json.dumps(body))

    model = self.__model_cls
    assert model and model.__index__, 'model has no index'

    resp = self.__session.search(
        body=body,
        index=model.__index__,
        size=0,  # hits are not needed, only the aggregations
        **kwargs,
    )
    return parse_aggregations(resp['aggregations'], depth=1)
|
||||
|
||||
|
||||
def parse_aggregations(data: dict, depth: int = 1):
    """Collapse a raw ``aggregations`` response into plain Python values.

    Aggregations are named by their depth (see ``Aggregation.compile``);
    JSON object keys are strings, so the lookup uses ``str(depth)``.

    Returns ``None`` when no aggregation exists at *depth*, a scalar for a
    metric aggregation, or a ``{bucket_key: value}`` dict for a bucket
    aggregation, where each value is the nested result when one exists and
    the bucket's ``doc_count`` otherwise.
    """
    level = data.get(str(depth))
    if level is None:
        return None

    if 'buckets' not in level:
        # Metric aggregation: a single scalar result.
        return level['value']

    result = {}
    for bucket in level['buckets']:
        child = parse_aggregations(bucket, depth + 1)
        # Compare against None rather than truthiness: a nested metric
        # result of 0 (or 0.0) is real data and must not silently fall
        # back to the bucket's doc_count.
        result[bucket['key']] = bucket['doc_count'] if child is None else child
    return result
|
||||
|
||||
Reference in New Issue
Block a user