Files
opensearch-orm/opensearchorm/session.py
2022-09-22 14:04:34 +08:00

168 lines
4.5 KiB
Python

import json
import logging
from typing import List, Optional, Type, TypeVar, cast
from opensearchpy import OpenSearch
from opensearchorm.model import BaseModel
from opensearchorm.query import ModelQuery, Expr
from opensearchorm.aggs import Aggregation, Sum, Cardinality
# Type variable standing for a concrete model class; bound to BaseModel.
Model = TypeVar('Model', bound=BaseModel)
class SearchSession:
    """Wrapper around an ``OpenSearch`` client that hands out query builders.

    Can be used as a context manager; leaving the ``with`` block closes the
    underlying client.
    """

    def __init__(self, host: str, user: str, password: str, **kwargs) -> None:
        """Create the underlying ``OpenSearch`` client.

        Args:
            host: Host entry, passed to the client as the only ``hosts`` item.
            user: User name for HTTP authentication.
            password: Password for HTTP authentication.
            **kwargs: Extra keyword arguments forwarded to ``OpenSearch``.
                A value given here overrides the matching connection default
                below instead of clashing with it.
        """
        options = {
            'http_compress': True,
            'use_ssl': True,
            'verify_certs': True,
            # NOTE(review): certificates are verified, yet
            # ssl_assert_hostname is False -- presumably this skips the
            # hostname match; confirm that this is intended.
            'ssl_assert_hostname': False,
            'ssl_show_warn': False,
        }
        # Caller-supplied options win over the defaults.  Passing both the
        # literal keywords and **kwargs used to raise "got multiple values
        # for keyword argument" as soon as a default was overridden.
        options.update(kwargs)
        self.client = OpenSearch(
            hosts=[host],
            http_auth=(user, password),
            **options,
        )

    def __enter__(self):
        """Return the session itself for use in a ``with`` statement."""
        return self

    def __exit__(self, type, value, traceback):
        """Close the client; exceptions from the block are not suppressed."""
        self.client.close()

    def select(self, model: Type[Model]):
        """Return a ``QueryExecutor`` for ``model`` bound to this session."""
        return QueryExecutor(model, self)

    def search(self, **kwargs):
        """Forward ``kwargs`` to ``client.search`` and return its response."""
        return self.client.search(**kwargs)

    def count(self, **kwargs):
        """Forward ``kwargs`` to ``client.count`` and return its response."""
        return self.client.count(**kwargs)
class QueryExecutor:
    """Fluent query builder bound to one model class and one session.

    The filter/limit/offset/values methods return ``self`` so calls can be
    chained; ``fetch``, ``count`` and the aggregation methods run the query
    through the session.
    """

    def __init__(self, model_cls: Type[Model], session: SearchSession):
        """Start an empty query for ``model_cls`` executed via ``session``."""
        self.__query = ModelQuery(model_cls)
        self.__model_cls = model_cls
        self.__include_fields: List[str] = []
        self.__limit: Optional[int] = None
        self.__offset: Optional[int] = None
        self.__session = session

    def __require_index(self):
        """Return the model's ``__index__`` or fail when it is missing/empty.

        Raises:
            AssertionError: If there is no model or its ``__index__`` is
                falsy.
        """
        model = self.__model_cls
        # Raised explicitly instead of using ``assert`` so the check still
        # runs under ``python -O``; the exception type and message are kept
        # for backward compatibility.
        if not (model and model.__index__):
            raise AssertionError('model has no index')
        return model.__index__

    def filter(self, *args: Expr, **kwargs):
        """Forward the conditions to ``ModelQuery.filter``; return ``self``."""
        self.__query.filter(*args, **kwargs)
        return self

    def union(self, *args: Expr, **kwargs):
        """Forward the conditions to ``ModelQuery.union``; return ``self``."""
        self.__query.union(*args, **kwargs)
        return self

    def exclude(self, *args: Expr, **kwargs):
        """Forward the conditions to ``ModelQuery.exclude``; return ``self``."""
        self.__query.exclude(*args, **kwargs)
        return self

    def limit(self, limit: int):
        """Set the ``size`` sent by ``fetch``; return ``self``."""
        self.__limit = limit
        return self

    def offset(self, offset: int):
        """Set the ``from_`` sent by ``fetch``; return ``self``."""
        self.__offset = offset
        return self

    def values(self, fields: List[str]):
        """Restrict ``fetch`` to ``fields`` and make it return raw dicts.

        Returns ``self`` for chaining.
        """
        self.__include_fields = fields
        return self

    def fetch(self, **kwargs):
        """Run the query and return the matching documents.

        Args:
            **kwargs: Extra keyword arguments forwarded to the session's
                ``search``.

        Returns:
            A list of raw ``_source`` dicts when fields were selected with
            ``values``; otherwise a list of model instances built with
            ``parse_obj``.

        Raises:
            AssertionError: If the model has no index.
        """
        body = {
            'query': self.__query.compile(),
        }
        logging.debug('query:\n%s', json.dumps(body))
        model = self.__model_cls
        index = self.__require_index()
        resp = self.__session.search(
            body=body,
            index=index,
            size=self.__limit,
            from_=self.__offset,
            _source_includes=self.__include_fields or model.default_fields(),
            **kwargs,
        )
        hits = resp['hits']['hits']
        logging.debug('raw result: %s', hits)
        if self.__include_fields:
            return [hit['_source'] for hit in hits]
        return [model.parse_obj(hit['_source']) for hit in hits]

    def scroll(self, **kwargs):
        """Placeholder: currently does nothing and returns ``None``."""
        ...

    def unique_count(self, field: str, is_text: bool = False, **kwargs) -> int:
        """Return the result of a ``Cardinality`` aggregation on ``field``."""
        resp = self.aggregate(Cardinality(field, is_text), **kwargs)
        return cast(int, resp)

    def sum(self, field: str, **kwargs) -> float:
        """Return the result of a ``Sum`` aggregation on ``field``."""
        resp = self.aggregate(Sum(field), **kwargs)
        # The cast now matches the declared return type (it used to be int).
        return cast(float, resp)

    def count(self, **kwargs) -> int:
        """Return the ``count`` reported by the session for the current query.

        Args:
            **kwargs: Extra keyword arguments forwarded to the session's
                ``count``.

        Raises:
            AssertionError: If the model has no index.
        """
        body = {
            'query': self.__query.compile(),
        }
        logging.debug('query:\n%s', json.dumps(body))
        index = self.__require_index()
        resp = self.__session.count(
            body=body,
            index=index,
            **kwargs,
        )
        return resp['count']

    def aggregate(self, aggs: Aggregation, **kwargs):
        """Run ``aggs`` over the current query and return the parsed result.

        The search is sent with ``size=0``, so only the ``aggregations``
        part of the response is read.

        Args:
            aggs: Aggregation to compile into the request body.
            **kwargs: Extra keyword arguments forwarded to the session's
                ``search``.

        Returns:
            Whatever ``parse_aggregations`` yields for the response.

        Raises:
            AssertionError: If the model has no index.
        """
        body = {
            'query': self.__query.compile(),
            'aggs': aggs.compile(depth=1),
        }
        logging.debug('query:\n%s', json.dumps(body))
        index = self.__require_index()
        resp = self.__session.search(
            body=body,
            index=index,
            size=0,
            **kwargs,
        )
        return parse_aggregations(resp['aggregations'], depth=1)
def parse_aggregations(data: dict, depth: int = 1):
    """Flatten an aggregation response whose levels are keyed by depth.

    The aggregation for a level is looked up under ``str(depth)``; nested
    aggregations inside a bucket are looked up under ``str(depth + 1)``.

    Args:
        data: Mapping that may contain the aggregation for this level.
        depth: Level number used as the lookup key.

    Returns:
        ``None`` when ``data`` has no entry for this level; the entry's
        ``'value'`` when it has no ``'buckets'``; otherwise a dict mapping
        each bucket key to its nested result, or to the bucket's
        ``doc_count`` when the bucket has no nested level.
    """
    level = data.get(str(depth))
    if level is None:
        return None
    if 'buckets' not in level:
        return level['value']

    result = {}
    for bucket in level['buckets']:
        children = parse_aggregations(bucket, depth + 1)
        # Only a missing nested level falls back to doc_count.  A truthiness
        # test here would also discard legitimate results such as a sum of 0.
        result[bucket['key']] = bucket['doc_count'] if children is None else children
    return result