# k_store_test.py: build and query the FAISS knowledge store.
import sys
sys.path.append("../") 
import time
from src.loader.load import loads_path,loads
from src.pgdb.knowledge.similarity import VectorStore_FAISS
from src.config.consts import (
    VEC_DB_DBNAME,
    VEC_DB_HOST,
    VEC_DB_PASSWORD,
    VEC_DB_PORT,
    VEC_DB_USER,
    EMBEEDING_MODEL_PATH,
    FAISS_STORE_PATH,
    SIMILARITY_SHOW_NUMBER,
    KNOWLEDGE_PATH,
    INDEX_NAME
)
from src.loader.callback import BaseCallback



class localCallback(BaseCallback):
    """Loader callback that rejects low-value chunks.

    A chunk is rejected when its title contains "思考题" (review/exercise
    questions); such chunks are ignored by default. Empty chunks and chunks
    whose text is too sparse relative to their line count are rejected too.
    """

    def filter(self, title: str, content: str) -> bool:
        """Return True when the (title, content) pair should be ignored.

        Args:
            title: Chunk title text.
            content: Chunk body text.

        Returns:
            True if the combined text is empty, if it averages fewer than 20
            characters per line (newline characters included in the count),
            or if the title contains "思考题"; False otherwise.
        """
        text_len = len(title) + len(content)
        if not text_len:
            return True
        # Non-empty text always yields at least one line, so this is never 0 here.
        line_count = len(title.splitlines()) + len(content.splitlines())
        too_sparse = text_len / line_count < 20
        return too_sparse or "思考题" in title


def _merge_adjacent_docs(docs, max_len):
    """Merge consecutive chunks that share a font size and page number.

    Chunks whose metadata lacks "font-size" or "page_number" are dropped,
    because they cannot be compared with their neighbours. A chunk is appended
    to the previously kept chunk when both keys match and the combined text
    length is below ``max_len``. Otherwise the previous chunk is emitted and
    the current chunk starts a new run.

    Note: merging mutates ``page_content`` of the kept chunk in place.

    Args:
        docs: Iterable of document objects exposing ``metadata`` (a mapping)
            and ``page_content`` (a str).
        max_len: Exclusive upper bound on the merged ``page_content`` length.

    Returns:
        A new list of (possibly merged) documents, in input order.
    """
    merged = []
    last_doc = None
    for doc in docs:
        # Validate every chunk, including the first one. Seeding ``last_doc``
        # unchecked used to raise KeyError on the next comparison.
        if "font-size" not in doc.metadata or "page_number" not in doc.metadata:
            continue
        if last_doc is None:
            last_doc = doc
            continue
        same_block = (
            doc.metadata["font-size"] == last_doc.metadata["font-size"]
            and doc.metadata["page_number"] == last_doc.metadata["page_number"]
        )
        if same_block and len(doc.page_content) + len(last_doc.page_content) < max_len:
            last_doc.page_content += doc.page_content
        else:
            merged.append(last_doc)
            last_doc = doc
    if last_doc is not None:
        merged.append(last_doc)
    return merged


def test_faiss_from_dir():
    """Load every document under KNOWLEDGE_PATH into the FAISS store and save it.

    Steps:
      1. Open the store with reset=True (presumably discarding any existing
         index; confirm in VectorStore_FAISS).
      2. Load documents in "paged" mode, filtering chunks with localCallback.
      3. Merge adjacent chunks with the same font size and page number.
      4. Add the documents in batches, printing the index size after each one.
      5. Persist the store with ``_save_local``.
    """
    sentence_size = 512
    batch_size = 300
    # Cap merged chunks at 3/4 of the loader's sentence size. This presumably
    # leaves headroom for the split done by ``_add_documents(need_split=True)``.
    max_merged_len = sentence_size / 4 * 3

    vecstore_faiss = VectorStore_FAISS(
        embedding_model_name=EMBEEDING_MODEL_PATH,
        store_path=FAISS_STORE_PATH,
        index_name=INDEX_NAME,
        info={"port": VEC_DB_PORT, "host": VEC_DB_HOST, "dbname": VEC_DB_DBNAME, "username": VEC_DB_USER, "password": VEC_DB_PASSWORD},
        show_number=3,
        reset=True,
    )
    docs = loads_path(KNOWLEDGE_PATH, mode="paged", sentence_size=sentence_size, callbacks=[localCallback()])
    print(len(docs))
    docs = _merge_adjacent_docs(docs, max_merged_len)
    print(len(docs))
    print(vecstore_faiss._faiss.index.ntotal)
    for start in range(0, len(docs), batch_size):
        # Python slicing clamps at len(docs), so the last batch needs no special case.
        vecstore_faiss._add_documents(docs[start:start + batch_size], need_split=True)
        print(vecstore_faiss._faiss.index.ntotal)
    vecstore_faiss._save_local()


def test_faiss_load():
    """Reopen the saved FAISS store and print the passages found for one query.

    The store is opened with reset=False, so it is expected to have been built
    beforehand (for example by test_faiss_from_dir). Up to
    SIMILARITY_SHOW_NUMBER results are requested.
    """
    db_info = {
        "port": VEC_DB_PORT,
        "host": VEC_DB_HOST,
        "dbname": VEC_DB_DBNAME,
        "username": VEC_DB_USER,
        "password": VEC_DB_PASSWORD,
    }
    store = VectorStore_FAISS(
        embedding_model_name=EMBEEDING_MODEL_PATH,
        store_path=FAISS_STORE_PATH,
        index_name=INDEX_NAME,
        info=db_info,
        show_number=SIMILARITY_SHOW_NUMBER,
        reset=False,
    )
    # Sample question: "Please describe your understanding of the international settlement business."
    query = "请介绍一下你理解的国际结算业务"
    hits = store.get_text_similarity(query)
    print(store._join_document(hits))


if __name__ == "__main__":
    # Build the index first, then reopen it from disk and run a sample query.
    for step in (test_faiss_from_dir, test_faiss_load):
        step()