diff --git a/opensearchorm/__init__.py b/opensearchorm/__init__.py index a8199cd..d310498 100644 --- a/opensearchorm/__init__.py +++ b/opensearchorm/__init__.py @@ -1,3 +1,4 @@ from .session import SearchSession from .model import BaseModel from .query import * +from .aggs import * diff --git a/opensearchorm/aggs.py b/opensearchorm/aggs.py new file mode 100644 index 0000000..d884a75 --- /dev/null +++ b/opensearchorm/aggs.py @@ -0,0 +1,65 @@ +import abc +from typing import Optional + + +class Aggregation(abc.ABC): + def __init__(self, field: str, is_text: bool = False) -> None: + self.field = f'{field}.keyword' if is_text else field + + @abc.abstractmethod + def compile(self, depth: int = 1): + ... + + +class MetricAggregation(Aggregation): + ... + + +class BucketAggregation(Aggregation): + @abc.abstractmethod + def nested(self, child: Aggregation): + ... + + +class Terms(BucketAggregation): + def __init__(self, field: str, is_text: bool = False, max_buckets: int = 100) -> None: + super().__init__(field, is_text) + self.max_buckets = max_buckets + self.child: Optional[Aggregation] = None + + def compile(self, depth: int = 1): + return { + depth: { + 'terms': { + 'field': self.field, + 'size': self.max_buckets, + }, + 'aggs': self.child.compile(depth + 1) if self.child else {}, + } + } + + def nested(self, child: Aggregation): + self.child = child + return self + + +class Cardinality(MetricAggregation): + def compile(self, depth: int = 1): + return { + depth: { + 'cardinality': { + 'field': self.field, + } + } + } + + +class Sum(MetricAggregation): + def compile(self, depth: int = 1): + return { + depth: { + 'sum': { + 'field': self.field, + } + } + } diff --git a/opensearchorm/query.py b/opensearchorm/query.py index 988d6fc..97796bb 100644 --- a/opensearchorm/query.py +++ b/opensearchorm/query.py @@ -119,6 +119,7 @@ class ModelQuery(Expr): 'must_not': [e.compile() for e in self.__exclude], 'should': [e.compile() for e in self.__union], 'filter': [e.compile() for e 
in self.__filter], + 'minimum_should_match': 1 if self.__union else 0, } } diff --git a/opensearchorm/session.py b/opensearchorm/session.py index aa3af3f..dae7058 100644 --- a/opensearchorm/session.py +++ b/opensearchorm/session.py @@ -1,10 +1,12 @@ +import json import logging -from typing import List, Optional, Type, TypeVar +from typing import List, Optional, Type, TypeVar, cast from opensearchpy import OpenSearch from opensearchorm.model import BaseModel from opensearchorm.query import ModelQuery, Expr +from opensearchorm.aggs import Aggregation, Cardinality, Sum Model = TypeVar('Model', bound=BaseModel) @@ -36,15 +38,18 @@ class SearchSession: def search(self, **kwargs): return self.client.search(**kwargs) + def count(self, **kwargs): + return self.client.count(**kwargs) + class QueryExecutor: - def __init__(self, model_cls: Type[Model], session): + def __init__(self, model_cls: Type[Model], session: SearchSession): self.__query = ModelQuery(model_cls) self.__model_cls = model_cls - self._include_fields = [] - self._limit: Optional[int] = None - self._offset: Optional[int] = None - self._session = session + self.__include_fields = [] + self.__limit: Optional[int] = None + self.__offset: Optional[int] = None + self.__session = session def filter(self, *args: Expr, **kwargs): self.__query.filter(*args, **kwargs) @@ -59,43 +64,104 @@ class QueryExecutor: def limit(self, limit: int): - self._limit = limit + self.__limit = limit return self def offset(self, offset: int): - self._offset = offset + self.__offset = offset return self def values(self, fields: List[str]): - self._include_fields = fields + self.__include_fields = fields return self - def fetch(self): + def fetch(self, **kwargs): body = { 'query': self.__query.compile(), } - - logging.debug('query:\n%s', body) - params = { - 'format': 'json', - 'request_timeout': 300, - } + logging.debug('query:\n%s', json.dumps(body)) model = self.__model_cls assert model and model.__index__, 'model has no index' - data = 
self._session.search( + resp = self.__session.search( body=body, - params=params, index=model.__index__, - size=self._limit, - from_=self._offset, - _source_includes=self._include_fields or model.default_fields(), + size=self.__limit, + from_=self.__offset, + _source_includes=self.__include_fields or model.default_fields(), + **kwargs, ) - hits = data['hits']['hits'] + hits = resp['hits']['hits'] logging.debug('raw result: %s', hits) - if self._include_fields: + if self.__include_fields: return [hit['_source'] for hit in hits] else: return [model.parse_obj(hit['_source']) for hit in hits] + + def scroll(self, **kwargs): + ... + + def unique_count(self, field: str, is_text: bool = False, **kwargs) -> int: + resp = self.aggregate(Cardinality(field, is_text), **kwargs) + return cast(int, resp) + + def sum(self, field: str, **kwargs) -> float: + resp = self.aggregate(Sum(field), **kwargs) + return cast(float, resp) + + def count(self, **kwargs) -> int: + body = { + 'query': self.__query.compile(), + } + logging.debug('query:\n%s', json.dumps(body)) + + model = self.__model_cls + assert model and model.__index__, 'model has no index' + + resp = self.__session.count( + body=body, + index=model.__index__, + **kwargs, + ) + return resp['count'] + + def aggregate(self, aggs: Aggregation, **kwargs): + body = { + 'query': self.__query.compile(), + 'aggs': aggs.compile(depth=1), + } + logging.debug('query:\n%s', json.dumps(body)) + + model = self.__model_cls + assert model and model.__index__, 'model has no index' + + resp = self.__session.search( + body=body, + index=model.__index__, + size=0, + **kwargs, + ) + + data = resp['aggregations'] + return parse_aggregations(data, depth=1) + + +def parse_aggregations(data: dict, depth: int = 1): + level = data.get(str(depth), None) + if level is None: + return + + if 'buckets' in level: + result = {} + buckets = level['buckets'] + for b in buckets: + key = b['key'] + count = b['doc_count'] + children = parse_aggregations(b, depth + 1) 
+ result[key] = children if children else count + return result + else: + value = level['value'] + return value