rag_test.py 2.37 KB
Newer Older
1 2 3
import sys
sys.path.append('../')

4
from src.server.get_similarity import QAExt,ChatExtend
5 6 7 8 9 10 11 12 13 14 15

from langchain_openai import ChatOpenAI


base_llm = ChatOpenAI(
    openai_api_key='xxxxxxxxxxxxx',
    openai_api_base='http://192.168.10.14:8000/v1',
    model_name='Qwen2-7B',
    verbose=True
)

16 17 18 19 20 21 22 23 24 25 26 27 28
def test_qaext():
    ext = QAExt(base_llm)
    question = "明天适合去吗?"
    message = [
        ("我们明天去爬山吧", "好呀"),
        ("天气怎么样", "天气晴朗"),
    ]

    
    result =  ext.extend_query(question, message)
    print(result.content)


29 30 31
def test_chatextend():
    ext = ChatExtend(base_llm)
    message = [
tinywell committed
32
        ("明天去爬山怎么样", "好主意"),
33 34 35 36 37 38
        ("天气怎么样", "天气晴朗"),
    ]

    result =  ext.new_questions(messages=message)
    print(result.content)

tinywell committed
39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57
def test_rrf():
    from langchain_core.documents import Document

    from src.server.rerank import reciprocal_rank_fusion

    docs = [
        Document(page_content="我需要查找海拔最高的十座高山的信息。这可能需要从一个数据库或在线资源中获取数据。我将使用一个假设的数据库来获取这些信息。",metadata={"font-size": 12, "page_number": 1}),
        Document(page_content="2019年,中国成年男性的平均身高为170厘米,女性为160厘米。",metadata={"font-size": 12, "page_number": 2}),
        Document(page_content="2020年,中国成年男性的平均身高为171厘米,女性为160厘米。",metadata={"font-size": 12, "page_number": 3}),
    ]

    docs2 =    [
        Document(page_content="我们的 llm_engine 必须是一个可调用的函数,它接受一系列消息作为输入并返回文本。它还需要接受一个 stop_sequences 参数,该参数指示何时停止其生成。为了方便起见,我们直接使用软件包中提供的 HfEngine 类来获取一个调用我们的推理 API 的LLM引擎。",metadata={"font-size": 12, "page_number": 1}),
        Document(page_content="由于我们将代理初始化为一个 ReactJsonAgent ,因此它会自动获得一个默认的系统提示,该提示告诉LLM引擎逐步处理并生成 JSON 块作为工具调用(您可以根据需要替换此提示模板)。",metadata={"font-size": 12, "page_number": 2}),
    ]

    res = reciprocal_rank_fusion([(60,docs),(55,docs2)])
    print(res)

58
if __name__ == "__main__":
tinywell committed
59 60
    # test_chatextend()
    test_rrf()