feat: aggs remove is_text field

This commit is contained in:
yim7
2022-09-22 15:29:50 +08:00
parent 7ab7581625
commit cb6ab84d51
5 changed files with 573 additions and 57 deletions

View File

@@ -3,8 +3,8 @@ from typing import Optional
class Aggregation(abc.ABC):
def __init__(self, field: str, is_text: bool = False) -> None:
self.field = f'{field}.keyword' if is_text else field
def __init__(self, field: str) -> None:
self.field = field
@abc.abstractmethod
def compile(self, depth: int = 1):
@@ -22,8 +22,8 @@ class BucketAggregation(Aggregation):
class Terms(BucketAggregation):
def __init__(self, field: str, is_text: bool = False, max_buckets: int = 100) -> None:
super().__init__(field, is_text)
def __init__(self, field: str, max_buckets: int = 100) -> None:
super().__init__(field)
self.max_buckets = max_buckets
self.child: Optional[Aggregation] = None

View File

@@ -1,28 +1,37 @@
import json
import logging
from typing import List, Optional, Type, TypeVar, cast
from typing import List, Union, Optional, Type, TypeVar, cast
from opensearchpy import OpenSearch
from opensearchorm.model import BaseModel
from opensearchorm.query import ModelQuery, Expr
from opensearchorm.aggs import Aggregation, Sum, Cardinality
from opensearchorm.utils import parse_aggregations
Host = Union[str, dict]
Model = TypeVar('Model', bound=BaseModel)
class SearchSession:
def __init__(self, host: str, user: str, password: str, **kwargs) -> None:
def __init__(self, hosts: Union[Host, List[Host]], user: str, password: str, **kwargs) -> None:
"""
:arg hosts: list of nodes, or a single node, we should connect to.
Node should be a dictionary ({"host": "localhost", "port": 9200}),
the entire dictionary will be passed to the :class:`~opensearchpy.Connection`
class as kwargs, or a string in the format of ``host[:port]`` which will be
translated to a dictionary automatically.
:arg user: http auth username
:arg password: http auth password
:arg kwargs: any additional arguments will be passed on to the opensearch-py call
"""
self.client = OpenSearch(
hosts=[
host,
],
hosts=hosts,
http_auth=(user, password),
http_compress=True,
use_ssl=True,
verify_certs=True,
ssl_assert_hostname=False,
ssl_show_warn=False,
**kwargs,
)
@@ -76,6 +85,10 @@ class QueryExecutor:
return self
def fetch(self, **kwargs):
"""
:arg kwargs: any additional arguments will be passed on to the opensearch-py call
"""
body = {
'query': self.__query.compile(),
}
@@ -103,8 +116,32 @@ class QueryExecutor:
def scroll(self, **kwargs):
...
def unique_count(self, field: str, is_text: bool = False, **kwargs) -> int:
resp = self.aggregate(Cardinality(field, is_text), **kwargs)
def aggregate(self, aggs: Aggregation, **kwargs):
    """
    Run ``aggs`` together with the current query and return the parsed result.

    :arg aggs: root aggregation; it is compiled with ``depth=1``
    :arg kwargs: any additional arguments will be passed on to the opensearch-py call
    :return: whatever ``parse_aggregations`` extracts from the
        ``aggregations`` section of the response
    """
    request_body = {}
    request_body['query'] = self.__query.compile()
    request_body['aggs'] = aggs.compile(depth=1)
    logging.debug('query:\n%s', json.dumps(request_body))

    model_cls = self.__model_cls
    # The search below targets the model's index, so a model without one cannot be used.
    assert model_cls and model_cls.__index__, 'model has no index'

    response = self.__session.search(
        body=request_body,
        index=model_cls.__index__,
        size=0,
        **kwargs,
    )
    return parse_aggregations(response['aggregations'], depth=1)
def unique_count(self, field: str, **kwargs) -> int:
    """
    Return the result of a ``Cardinality`` aggregation on ``field``.

    :arg field: name of the field to aggregate on
    :arg kwargs: forwarded unchanged to ``aggregate``
    """
    cardinality = Cardinality(field)
    return cast(int, self.aggregate(cardinality, **kwargs))
def sum(self, field: str, **kwargs) -> float:
@@ -126,42 +163,3 @@ class QueryExecutor:
**kwargs,
)
return resp['count']
def aggregate(self, aggs: Aggregation, **kwargs):
    """
    Run ``aggs`` together with the current query and return the parsed result.

    :arg aggs: root aggregation; it is compiled with ``depth=1``
    :arg kwargs: any additional arguments will be passed on to the search call
    :return: whatever ``parse_aggregations`` extracts from the
        ``aggregations`` section of the response
    """
    request_body = {}
    request_body['query'] = self.__query.compile()
    request_body['aggs'] = aggs.compile(depth=1)
    logging.debug('query:\n%s', json.dumps(request_body))

    model_cls = self.__model_cls
    # The search below targets the model's index, so a model without one cannot be used.
    assert model_cls and model_cls.__index__, 'model has no index'

    response = self.__session.search(
        body=request_body,
        index=model_cls.__index__,
        size=0,
        **kwargs,
    )
    return parse_aggregations(response['aggregations'], depth=1)
def parse_aggregations(data: dict, depth: int = 1):
    """
    Convert a nested aggregation response into plain Python values.

    The aggregation for a given nesting level is looked up under the key
    ``str(depth)``; nested aggregations inside a bucket use ``depth + 1``.

    :arg data: mapping that may contain an aggregation under ``str(depth)``
        (the ``aggregations`` section of a response, or a single bucket)
    :arg depth: nesting level to read from ``data``
    :return: ``None`` when ``data`` has no entry for ``depth``; for an entry
        with ``buckets``, a dict mapping each bucket ``key`` to its parsed
        child aggregation, or to the bucket's ``doc_count`` when the bucket
        has no child aggregation; otherwise the entry's ``value``
    """
    level = data.get(str(depth))
    if level is None:
        return None

    if 'buckets' not in level:
        return level['value']

    result = {}
    for bucket in level['buckets']:
        key = bucket['key']
        count = bucket['doc_count']
        children = parse_aggregations(bucket, depth + 1)
        # Test for None, not truthiness: a child result of 0 (or an empty
        # dict) is a real answer and must not be replaced by doc_count.
        result[key] = count if children is None else children
    return result

17
opensearchorm/utils.py Normal file
View File

@@ -0,0 +1,17 @@
def parse_aggregations(data: dict, depth: int = 1):
    """
    Convert a nested aggregation response into plain Python values.

    The aggregation for a given nesting level is looked up under the key
    ``str(depth)``; nested aggregations inside a bucket use ``depth + 1``.

    :arg data: mapping that may contain an aggregation under ``str(depth)``
        (the ``aggregations`` section of a response, or a single bucket)
    :arg depth: nesting level to read from ``data``
    :return: ``None`` when ``data`` has no entry for ``depth``; for an entry
        with ``buckets``, a dict mapping each bucket ``key`` to its parsed
        child aggregation, or to the bucket's ``doc_count`` when the bucket
        has no child aggregation; otherwise the entry's ``value``
    """
    level = data.get(str(depth))
    if level is None:
        return None

    if 'buckets' not in level:
        return level['value']

    result = {}
    for bucket in level['buckets']:
        key = bucket['key']
        count = bucket['doc_count']
        children = parse_aggregations(bucket, depth + 1)
        # Test for None, not truthiness: a child result of 0 (or an empty
        # dict) is a real answer and must not be replaced by doc_count.
        result[key] = count if children is None else children
    return result