# callback.py
import sys
2 3 4

from abc import ABC, abstractmethod
import json
陈正乐 committed
5
from typing import List, Tuple
6
from langchain_core.documents import Document
陈正乐 committed
7
from src.pgdb.knowledge.pgsqldocstore import PgSqlDocstore, str2hash_base64
8

陈正乐 committed
9 10
# Append the relative path "../" to the module search path.
# NOTE(review): this runs after the imports above, so it cannot help resolve
# the `src.pgdb...` import in this module -- confirm whether it was meant to
# come before that import or is only needed by modules imported later.
sys.path.append("../")

11 12

class DocumentCallback(ABC):
    """Interface for hooks run before storing and after searching documents."""

    @abstractmethod
    def before_store(self, docstore: PgSqlDocstore, documents):
        """Process documents before they are written to the vector store.

        Args:
            docstore: Document store made available to the hook.
            documents: Documents that are about to be stored.
        """
        pass

    @abstractmethod
    def after_search(self, docstore: PgSqlDocstore, documents: List[Tuple[Document, float]], number: int = 1000) -> \
            List[Tuple[Document, float]]:
        """Process documents returned by a vector-store search.

        Used to rebuild document structure from the raw search hits.

        Args:
            docstore: Document store made available to the hook.
            documents: Search hits as ``(document, score)`` pairs.
            number: Limit on the number of results (defaults to 1000).

        Returns:
            The processed ``(document, score)`` pairs.
        """
        pass
陈正乐 committed
21

22 23

class DefaultDocumentCallback(DocumentCallback):
    """Default hooks: link documents by hash and expand search hits along the links."""

    def before_store(self, docstore: PgSqlDocstore, documents):
        """Replace each document's ``next_doc`` metadata entry with its hash.

        Documents are modified in place: when ``metadata`` contains
        ``"next_doc"``, that entry is removed and ``"next_hash"`` is set to
        ``str2hash_base64`` of its value.

        Args:
            docstore: Document store (not used by this implementation).
            documents: Iterable of documents exposing a ``metadata`` dict.

        Returns:
            A list containing the same document objects, in input order.
        """
        output_doc = []
        for doc in documents:
            if "next_doc" in doc.metadata:
                # Store only the hash of the linked text, not the text itself.
                doc.metadata["next_hash"] = str2hash_base64(doc.metadata.pop("next_doc"))
            output_doc.append(doc)
        return output_doc

    def after_search(self, docstore: PgSqlDocstore, documents: List[Tuple[Document, float]], number: int = 1000) -> \
            List[Tuple[Document, float]]:
        """De-duplicate search hits and append the documents they link to.

        Each hit is emitted once (keyed by the hash of its ``page_content``).
        After a hit, its ``next_hash`` chain is followed through
        ``docstore.TXT_DOC.search``; every linked document found is appended
        with the score of the hit that started the chain.

        Args:
            docstore: Store whose ``TXT_DOC.search(hash)`` returns a sequence
                with the text at index 0 and JSON-encoded metadata at index 1,
                or a falsy value when nothing is found.
            documents: Search hits as ``(document, score)`` pairs.
            number: Maximum number of pairs to return.

        Returns:
            At most ``number`` ``(document, score)`` pairs; an empty list when
            ``number`` is not positive.
        """
        output_doc: List[Tuple[Document, float]] = []
        if number <= 0:
            return output_doc

        # A set gives O(1) membership checks for already-emitted hashes.
        seen_hashes = set()
        for doc, score in documents:
            doc_hash = str2hash_base64(doc.page_content)
            if doc_hash in seen_hashes:
                continue
            seen_hashes.add(doc_hash)

            output_doc.append((doc, score))
            if len(output_doc) >= number:
                return output_doc

            # Walk the chain of linked documents that starts at this hit.
            current = doc
            while True:
                next_hash = current.metadata.get("next_hash")
                # Stop on a missing/empty link or one that was already visited
                # (the latter also protects against cycles).
                if not next_hash or next_hash in seen_hashes:
                    break
                seen_hashes.add(next_hash)

                content = docstore.TXT_DOC.search(next_hash)
                if not content:
                    break
                current = Document(page_content=content[0], metadata=json.loads(content[1]))
                output_doc.append((current, score))
                if len(output_doc) >= number:
                    return output_doc
        return output_doc