diff --git a/opensearchorm/session.py b/opensearchorm/session.py index fc95997..d9e1fd2 100644 --- a/opensearchorm/session.py +++ b/opensearchorm/session.py @@ -47,6 +47,13 @@ class SearchSession: def search(self, **kwargs): return self.client.search(**kwargs) + def scroll(self, scroll_id, lifetime): + body = dict( + scroll_id=scroll_id, + scroll=lifetime, + ) + return self.client.scroll(body=body) + def count(self, **kwargs): return self.client.count(**kwargs) @@ -79,7 +86,7 @@ class QueryExecutor(Generic[Model]): self.__offset = offset return self - def fetch_fields(self, fields: List[str], **kwargs): + def _search(self, fields: List[str], **kwargs): """ :arg fields: include source fields @@ -103,6 +110,17 @@ class QueryExecutor(Generic[Model]): **kwargs, ) + return resp + + def fetch_fields(self, fields: List[str], **kwargs): + """ + :arg fields: include source fields + + :arg kwargs: any additional arguments will be passed on to the opensearch-py call + """ + + resp = self._search(fields, **kwargs) + hits = resp['hits']['hits'] logging.debug('raw result: %s', hits) return [hit['_source'] for hit in hits] @@ -112,11 +130,30 @@ class QueryExecutor(Generic[Model]): :arg kwargs: any additional arguments will be passed on to the opensearch-py call """ model = self.__model_cls - hits = self.fetch_fields(model.default_fields()) + hits = self.fetch_fields(model.default_fields(), **kwargs) return [model.parse_obj(hit) for hit in hits] - def scroll(self, **kwargs): - ... 
def scroll(self, lifetime, **kwargs):
    """
    Iterate over the whole result set page by page using the scroll API.

    :arg lifetime: how long the search context is kept alive between
        requests; sent as ``scroll`` with the initial search and with every
        follow-up scroll call

    :arg kwargs: any additional arguments will be passed on to the opensearch-py call

    :return: generator yielding one list of parsed models per page. Paging
        stops after the first empty page (which is still yielded) or as soon
        as a response carries no scroll id.
    """
    model = self.__model_cls

    resp = self._search(model.default_fields(), scroll=lifetime, **kwargs)
    log_format = 'raw result: %s'
    scroll_id = None
    try:
        while True:
            # Every response may carry a new scroll id; always use the latest.
            scroll_id = resp['_scroll_id']
            hits = resp['hits']['hits']
            logging.debug(log_format, hits)
            data = [model.parse_obj(hit['_source']) for hit in hits]
            yield data

            if not (scroll_id and data):
                break

            resp = self.__session.scroll(scroll_id, lifetime)
            log_format = 'scroll raw result: %s'
    finally:
        # Reached on exhaustion, on errors and when the consumer drops the
        # generator early. Without this the search context stays open on the
        # cluster until `lifetime` expires.
        if scroll_id:
            try:
                self.__session.client.clear_scroll(body={'scroll_id': scroll_id})
            except Exception:
                # Cleanup is best effort: an already expired context must not
                # hide the outcome of the iteration itself.
                logging.warning('failed to clear scroll context', exc_info=True)