import json
import logging
from typing import Any, Dict, List, Optional, Type, TypeVar, cast

from opensearchpy import OpenSearch

from opensearchorm.model import BaseModel
from opensearchorm.query import ModelQuery, Expr
from opensearchorm.aggs import *  # noqa: F401,F403 - provides Aggregation, Cardinality, Sum

# Module-level logger so applications can configure this library's logging
# independently instead of it writing to the root logger.
logger = logging.getLogger(__name__)

Model = TypeVar('Model', bound=BaseModel)


def _log_query(body: dict) -> None:
    """Log a request body as JSON at DEBUG level.

    Serialization only happens when DEBUG is enabled, and values the ``json``
    module cannot encode by default are stringified. Building the request
    therefore never fails just because of this debug log.
    """
    if logger.isEnabledFor(logging.DEBUG):
        logger.debug('query:\n%s', json.dumps(body, default=str))


class SearchSession:
    """Context-managed wrapper around an ``opensearchpy.OpenSearch`` client."""

    def __init__(self, host: str, user: str, password: str, **kwargs) -> None:
        """Create the underlying OpenSearch client.

        :param host: single host entry, passed as ``hosts=[host]``.
        :param user: HTTP basic-auth user name.
        :param password: HTTP basic-auth password.
        :param kwargs: extra ``OpenSearch`` options. They are merged over the
            defaults below, so passing e.g. ``use_ssl=False`` overrides the
            default instead of raising ``TypeError`` for a duplicate keyword.
        """
        options: Dict[str, Any] = {
            'http_compress': True,
            'use_ssl': True,
            'verify_certs': True,
            'ssl_assert_hostname': False,
            'ssl_show_warn': False,
        }
        options.update(kwargs)
        self.client = OpenSearch(
            hosts=[host],
            http_auth=(user, password),
            **options,
        )

    def __enter__(self) -> 'SearchSession':
        return self

    def __exit__(self, exc_type, exc_value, traceback) -> None:
        # Always release the client; returning None lets exceptions propagate.
        self.client.close()

    def select(self, model: Type[Model]) -> 'QueryExecutor':
        """Start a query against ``model``'s index."""
        return QueryExecutor(model, self)

    def search(self, **kwargs):
        """Forward to ``OpenSearch.search``."""
        return self.client.search(**kwargs)

    def count(self, **kwargs):
        """Forward to ``OpenSearch.count``."""
        return self.client.count(**kwargs)


class QueryExecutor:
    """Fluent query builder bound to one model class and one session.

    ``filter``/``union``/``exclude``/``limit``/``offset``/``values`` update
    this executor and return ``self`` so calls can be chained. ``fetch``,
    ``count`` and ``aggregate`` send the request through the session.
    """

    def __init__(self, model_cls: Type[Model], session: SearchSession):
        self.__query = ModelQuery(model_cls)
        self.__model_cls = model_cls
        self.__include_fields: List[str] = []
        self.__limit: Optional[int] = None
        self.__offset: Optional[int] = None
        self.__session = session

    def filter(self, *args: Expr, **kwargs) -> 'QueryExecutor':
        self.__query.filter(*args, **kwargs)
        return self

    def union(self, *args: Expr, **kwargs) -> 'QueryExecutor':
        self.__query.union(*args, **kwargs)
        return self

    def exclude(self, *args: Expr, **kwargs) -> 'QueryExecutor':
        self.__query.exclude(*args, **kwargs)
        return self

    def limit(self, limit: int) -> 'QueryExecutor':
        self.__limit = limit
        return self

    def offset(self, offset: int) -> 'QueryExecutor':
        self.__offset = offset
        return self

    def values(self, fields: List[str]) -> 'QueryExecutor':
        """Restrict ``_source`` to ``fields``; ``fetch`` then returns raw dicts."""
        # Copy so later mutation of the caller's list cannot alter this query.
        self.__include_fields = list(fields)
        return self

    def _index_name(self) -> str:
        """Return the model's index name, raising AssertionError if it has none."""
        model = self.__model_cls
        if not (model and model.__index__):
            # Explicit raise (not ``assert``) so the check survives ``python -O``;
            # AssertionError is kept for compatibility with existing callers.
            raise AssertionError('model has no index')
        return model.__index__

    def fetch(self, **kwargs):
        """Run the query and return the matching documents.

        :returns: the hits' ``_source`` dicts when :meth:`values` selected
            fields, otherwise model instances built with ``model.parse_obj``.
        """
        body = {'query': self.__query.compile()}
        _log_query(body)
        index = self._index_name()
        model = self.__model_cls
        # NOTE(review): ``size``/``from_`` may be None when limit/offset were not
        # set; presumably the client then omits them and server defaults apply.
        resp = self.__session.search(
            body=body,
            index=index,
            size=self.__limit,
            from_=self.__offset,
            _source_includes=self.__include_fields or model.default_fields(),
            **kwargs,
        )
        hits = resp['hits']['hits']
        logger.debug('raw result: %s', hits)
        if self.__include_fields:
            return [hit['_source'] for hit in hits]
        return [model.parse_obj(hit['_source']) for hit in hits]

    def scroll(self, **kwargs):
        """Not implemented yet; currently returns None."""
        ...

    def unique_count(self, field: str, is_text: bool = False, **kwargs) -> int:
        """Return the cardinality aggregation value for ``field``."""
        resp = self.aggregate(Cardinality(field, is_text), **kwargs)
        return cast(int, resp)

    def sum(self, field: str, **kwargs) -> float:
        """Return the sum aggregation value for ``field``."""
        resp = self.aggregate(Sum(field), **kwargs)
        return cast(float, resp)

    def count(self, **kwargs) -> int:
        """Return the number of documents matching the query."""
        body = {'query': self.__query.compile()}
        _log_query(body)
        resp = self.__session.count(
            body=body,
            index=self._index_name(),
            **kwargs,
        )
        return resp['count']

    def aggregate(self, aggs: Aggregation, **kwargs):
        """Run ``aggs`` over the matching documents (no hits returned).

        :returns: the result of :func:`parse_aggregations` on the response.
        """
        body = {
            'query': self.__query.compile(),
            'aggs': aggs.compile(depth=1),
        }
        _log_query(body)
        resp = self.__session.search(
            body=body,
            index=self._index_name(),
            size=0,
            **kwargs,
        )
        return parse_aggregations(resp['aggregations'], depth=1)


def parse_aggregations(data: dict, depth: int = 1):
    """Flatten a depth-keyed ``aggregations`` response into plain Python values.

    The aggregation at each level is looked up under the key ``str(depth)``.
    These are presumably the names produced by ``Aggregation.compile(depth=...)``.

    * A bucket aggregation becomes ``{bucket_key: value}``. ``value`` is the
      parsed child aggregation when the bucket contains one (key
      ``str(depth + 1)``); otherwise it is the bucket's ``doc_count``.
    * A metric aggregation returns its ``value``.
    * Returns None when no aggregation named ``str(depth)`` is present.
    """
    level = data.get(str(depth))
    if level is None:
        return None
    if 'buckets' not in level:
        return level['value']

    child_key = str(depth + 1)
    buckets = level['buckets']
    # Keyed bucket aggregations (e.g. ``filters``) return a dict of buckets
    # instead of a list; the dict key is then the bucket key.
    if isinstance(buckets, dict):
        items = buckets.items()
    else:
        items = ((bucket['key'], bucket) for bucket in buckets)

    result = {}
    for key, bucket in items:
        # Check presence, not truthiness: a child metric of 0 or an empty child
        # bucket set is a real result and must not be replaced by doc_count.
        if child_key in bucket:
            result[key] = parse_aggregations(bucket, depth + 1)
        else:
            result[key] = bucket['doc_count']
    return result