Commit 558e898c by 陈正乐

init

parent 4662b93b
data/
docker/
docs/
images/
model/
src/tuning/
src/tools/
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
env/
build/
dist/
eggs/
*.egg-info/
bin/
include/
lib/
local/
man/
share/
pip-wheel-metadata/
htmlcov/
.coverage
.tox/
.pytest_cache/
pytest.ini
# PyCharm
.idea/
# VSCode
.vscode/
# Jupyter Notebook
.ipynb_checkpoints
# Django
*.log
*.pot
*.pyc
local_settings.py
db.sqlite3
db.sqlite3-journal
media
# Flask
instance/
.webassets-cache
# Sphinx documentation
docs/_build/
model/
data/
exam/
.env
src/vector/faiss_store
src/scenarios/psbc/tag_memory_store/vectorstore
Python-3.11.0.tgz
Python-3.10.8.tgz
OpenSSL_1_1_1d.tar.gz
deps/averaged_perceptron_tagger.zip
deps/punkt.zip
deps/nltk.py
deps/averaged_perceptron_tagger.tar.gz
deps/punkt.tar.gz
FROM ubuntu:22.04 as base
# COPY sources.list /etc/apt/sources.list
RUN apt update && apt -y upgrade
RUN apt install -y gcc make wget perl zlib1g-dev libffi-dev libbz2-dev libreadline-dev liblzma-dev libsqlite3-dev
#RUN wget https://www.python.org/ftp/python/3.10.8/Python-3.10.8.tgz
#RUN wget https://github.com/openssl/openssl/archive/OpenSSL_1_1_1d.tar.gz
# use locally downloaded archives instead
COPY deps/Python-3.10.8.tgz /
COPY deps/OpenSSL_1_1_1d.tar.gz /
RUN cd / && tar -zxf OpenSSL_1_1_1d.tar.gz
RUN cd openssl-OpenSSL_1_1_1d && ./config --prefix=/usr/local/openssl && make && make install
RUN rm -f /usr/bin/openssl /usr/lib64/openssl /usr/lib64/libssl.so \
&& ln -s /usr/local/openssl/bin/openssl /usr/bin/openssl \
&& ln -s /usr/local/openssl/include/openssl /usr/include/openssl \
&& ln -s /usr/local/openssl/lib/libssl.so /usr/lib64/libssl.so \
&& echo "/usr/local/openssl/lib" >> /etc/ld.so.conf \
&& ldconfig -v
RUN cd / && tar -zxf Python-3.10.8.tgz
RUN cd Python-3.10.8/ \
&& ./configure --enable-optimizations --prefix=/usr/local/python3.10 --with-openssl=/usr/local/openssl \
&& make && make install
# && rm -rf /Python-3.10.8.tgz /Python-3.10.8 /OpenSSL_1_1_1d.tar.gz /openssl-OpenSSL_1_1_1d
FROM nvidia/cuda:12.1.0-runtime-ubuntu22.04
COPY deps/sources.list /etc/apt/sources.list
RUN apt update && apt -y upgrade
RUN apt install -y gcc wget perl vim net-tools libpq-dev
RUN useradd -m aigc && usermod -s /bin/bash aigc && usermod -G sudo aigc
COPY --from=base /usr/local/openssl /usr/local/openssl
COPY --from=base /usr/local/python3.10 /usr/local/python3.10
RUN rm -f /usr/bin/openssl /usr/lib64/openssl /usr/lib64/libssl.so \
&& ln -s /usr/local/openssl/bin/openssl /usr/bin/openssl \
&& ln -s /usr/local/openssl/include/openssl /usr/include/openssl \
&& ln -s /usr/local/openssl/lib/libssl.so /usr/lib64/libssl.so \
&& echo "/usr/local/openssl/lib" >> /etc/ld.so.conf \
&& ldconfig -v
RUN ln -s /usr/local/python3.10/bin/python3.10 /usr/local/bin/python3 \
&& ln -s /usr/local/python3.10/bin/pip3.10 /usr/local/bin/pip3 \
&& ln -s /usr/local/bin/python3 /usr/bin/python \
&& echo "export PATH=\$PATH:/usr/local/python3.10/bin" >> /etc/profile \
&& rm -f /usr/bin/pip && ln -s /usr/local/bin/pip3 /usr/bin/pip
ADD deps/punkt.tar.gz /usr/local/python3.10/nltk_data/tokenizers/
ADD deps/averaged_perceptron_tagger.tar.gz /usr/local/python3.10/nltk_data/taggers/
WORKDIR /home/aigc/
RUN mkdir .beai
RUN apt update && apt install -y libreoffice
RUN pip config set global.index-url https://pypi.tuna.tsinghua.edu.cn/simple
COPY deps/requirements.txt requirements.txt
RUN pip install -r requirements.txt
# RUN python -m pip install --upgrade pip && pip install faiss-gpu
COPY . .
# WORKDIR /home/aigc/src/scenarios/spdsvb
# USER aigc
EXPOSE 5000
EXPOSE 8001
EXPOSE 8002
CMD ["bash"]
IMAGE_NAME = "brilliance/aigc_llm:0.2.0"
.PHONY: image
image:
docker build -t $(IMAGE_NAME) .
bitsandbytes==0.41.1
cpm-kernels==1.0.11
fastapi==0.100.0
Flask==2.3.2
jieba==0.42.1
langchain==0.0.278
peft==0.4.0
psycopg2==2.9.7
pydantic==1.10.12
requests==2.31.0
sentence-transformers==2.2.2
torch==2.0.1
transformers==4.31.0
uvicorn==0.23.1
unstructured==0.8.1
qianfan==0.0.5
faiss-gpu==1.7.2 # https://github.com/facebookresearch/faiss/blob/main/INSTALL.md
\ No newline at end of file
# The deb-src entries are commented out by default to speed up apt update; uncomment them if needed
deb https://mirrors.tuna.tsinghua.edu.cn/ubuntu/ jammy main restricted universe multiverse
# deb-src https://mirrors.tuna.tsinghua.edu.cn/ubuntu/ jammy main restricted universe multiverse
deb https://mirrors.tuna.tsinghua.edu.cn/ubuntu/ jammy-updates main restricted universe multiverse
# deb-src https://mirrors.tuna.tsinghua.edu.cn/ubuntu/ jammy-updates main restricted universe multiverse
deb https://mirrors.tuna.tsinghua.edu.cn/ubuntu/ jammy-backports main restricted universe multiverse
# deb-src https://mirrors.tuna.tsinghua.edu.cn/ubuntu/ jammy-backports main restricted universe multiverse
# deb https://mirrors.tuna.tsinghua.edu.cn/ubuntu/ jammy-security main restricted universe multiverse
# # deb-src https://mirrors.tuna.tsinghua.edu.cn/ubuntu/ jammy-security main restricted universe multiverse
deb http://security.ubuntu.com/ubuntu/ jammy-security main restricted universe multiverse
# deb-src http://security.ubuntu.com/ubuntu/ jammy-security main restricted universe multiverse
# Pre-release repositories; enabling them is not recommended
# deb https://mirrors.tuna.tsinghua.edu.cn/ubuntu/ jammy-proposed main restricted universe multiverse
# # deb-src https://mirrors.tuna.tsinghua.edu.cn/ubuntu/ jammy-proposed main restricted universe multiverse
\ No newline at end of file
#!/bin/sh
# This will start a code-server container and expose it at http://127.0.0.1:8080.
# It will also mount your current directory into the container as `/home/coder/project`
# and forward your UID/GID so that all file system operations occur as your user outside
# the container.
#
# Your $HOME/.config is mounted at $HOME/.config within the container to ensure you can
# easily access/modify your code-server config in $HOME/.config/code-server/config.json
# outside the container.
mkdir -p ~/.config
docker run -d --name code-server -p 8443:8080 \
-v "$HOME/.config:/home/coder/.config" \
-v "$HOME/:/home/coder/project" \
-u "$(id -u):$(id -g)" \
codercom/code-server:latest
\ No newline at end of file
version: '3.8'
services:
db:
image: postgres:15-alpine3.17
restart: always
environment:
POSTGRES_USER: vecdoc
POSTGRES_PASSWORD: vecdoc
POSTGRES_DB: vecdoc
ports:
- "5432:5432"
volumes:
- db-data:/var/lib/postgresql/data
volumes:
db-data:
\ No newline at end of file
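A minimal connectivity check against the Postgres service defined in the compose file above, assuming the stack is running locally and the published port 5432 is reachable; psycopg2 is already pinned in deps/requirements.txt.
import psycopg2  # pinned in deps/requirements.txt

# Credentials mirror the environment block above; host/port are assumptions for a local deployment.
conn = psycopg2.connect(host="127.0.0.1", port=5432, dbname="vecdoc", user="vecdoc", password="vecdoc")
with conn, conn.cursor() as cur:
    cur.execute("SELECT version();")
    print(cur.fetchone()[0])
conn.close()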
#!/bin/bash
DB_CONTAINER_NAME=${1:-'vector_db'}
DB_NAME=${2:-'vecdoc'}
DB_USER=${3:-'vecdoc'}
STORE_PATH=${4:-'export.sql'}
# if [[ -z $1 || -z $2 || -z $3 ]];then
# echo "need input container name, db name, user name "
# echo "***.sh containername dbname username [storepath]"
# echo "default storepath ./export.sql"
# echo "- ./export.sql:/docker-entrypoint-initdb.d/export.sql"
# exit 1
# fi
# storepath=$4
# if [ -z $4 ];then
# storepath = "./export.sql"
set -x
docker exec -i $DB_CONTAINER_NAME pg_dump -d $DB_NAME -U $DB_USER > $STORE_PATH
set +x
\ No newline at end of file
{
"swagger": "2.0",
"info": {
"version": "1.0.0",
"title": "AI Chatbot API"
},
"basePath": "/aigc",
"schemes": [
"http"
],
"consumes": [
"application/json"
],
"produces": [
"application/json"
],
"paths": {
"/ask": {
"post": {
"summary": "Chat with the AI chatbot",
"description": "Send a question to the AI chatbot and receive a response",
"parameters": [
{
"name": "body",
"in": "body",
"description": "The request body",
"required": true,
"schema": {
"type": "object",
"properties": {
"question": {
"type": "string"
},
"modelOptions": {
"type": "object",
"properties": {
"isEnhancement": {
"type": "boolean",
"description": "Whether to use the enhancement model"
},
"isExpert": {
"type": "boolean",
"description": "Whether to use the expert model"
},
"isCommon": {
"type": "boolean",
"description": "Whether to use the common model"
},
"sliderTemp": {
"type": "number",
"description": "The temperature of the response"
}
}
},
"dialog": {
"type": "array",
"items": {
"type": "object",
"properties": {
"q": {
"type": "string",
"description": "The question in the dialog"
},
"a": {
"type": "string",
"description": "The answer in the dialog"
}
}
},
"description": "The dialog history"
}
}
}
}
],
"responses": {
"200": {
"description": "Successful response",
"schema": {
"type": "object",
"properties": {
"code": {
"type": "integer",
"format": "int32"
},
"msg": {
"type": "string"
},
"data": {
"type": "object",
"properties": {
"q": {
"type": "string",
"description": "The input question"
},
"a": {
"type": "string",
"description": "The response answer"
}
}
}
}
}
},
"400": {
"description": "Invalid request"
},
"500": {
"description": "Internal server error"
}
}
}
},
"/docqa": {
"post": {
"summary": "Answer a question based on uploaded documents",
"description": "This endpoint accepts a POST request with a JSON payload containing a query and optional parameters. It returns a JSON response containing the answer to the query.",
"consumes": [
"multipart/form-data"
],
"parameters": [
{
"name": "params",
"in": "formData",
"description": "JSON payload containing query and optional parameters",
"required": true,
"type": "object",
"properties":{
"chatid":{
"type":"string",
"description":"会话 id"
},
"query":{
"type":"string",
"description":"用户输入的问题"
},
"chain_type":{
"type":"string",
"description":"链类型"
},
"detail":{
"type":"boolean",
"description":"是否返回关联知识(default: true)"
},
"summary":{
"type":"boolean",
"description":"上传文档时是否进行总结(default: false)"
}
}
},
{
"name": "file",
"in": "formData",
"description": "File(s) to be uploaded for document retrieval",
"required": false,
"type": "file"
}
],
"responses": {
"200": {
"description": "Successful response",
"schema": {
"type": "object",
"properties": {
"code": {
"type": "integer",
"description": "HTTP status code"
},
"msg": {
"type": "string",
"description": "Response message"
},
"data": {
"type": "object",
"properties": {
"q": {
"type": "string",
"description": "Query string"
},
"a": {
"type": "string",
"description": "Answer to the query"
},
"similarity":{
"type":"array",
"description":"关联文档"
}
}
}
}
}
},
"400": {
"description": "Bad request",
"schema": {
"type": "object",
"properties": {
"code": {
"type": "integer",
"description": "HTTP status code"
},
"msg": {
"type": "string",
"description": "Error message"
}
}
}
},
"500": {
"description": "Internal server error",
"schema": {
"type": "object",
"properties": {
"code": {
"type": "integer",
"description": "HTTP status code"
},
"msg": {
"type": "string",
"description": "Error message"
}
}
}
}
}
}
}
}
}
\ No newline at end of file
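A minimal client sketch for the /ask endpoint described in the spec above; the host and port are assumptions for a local deployment (the Dockerfile exposes 5000) and the basePath /aigc comes from the spec.
import requests

# Hypothetical local endpoint; field names follow the Swagger definition above.
resp = requests.post(
    "http://127.0.0.1:5000/aigc/ask",
    json={
        "question": "你好",
        "modelOptions": {"isCommon": True, "sliderTemp": 0.95},
        "dialog": [],
    },
    timeout=300,
)
resp.raise_for_status()
print(resp.json())  # expected shape: {"code": ..., "msg": ..., "data": {"q": ..., "a": ...}}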
#!/bin/bash
# Define source and destination directories
src_dir="/home/zfh/aird"
dest_dir="/home/zfh/aird_backup_his/$(date +%F)"
# clear the old backup
if [ -d "$dest_dir" ]; then
rm -rf "$dest_dir"
rm -f $dest_dir.tar.gz
fi
mkdir -p "$dest_dir"
if [ ! -d "$dest_dir/model/ckpt" ]; then
mkdir -p "$dest_dir/model/ckpt"
fi
# Copy directories and files to destination directory
exclusions=(tools tuning .env)
exclusions+=("__pycache__/")
rsync -av "${exclusions[@]/#/--exclude=}" "${src_dir}/src/" "${dest_dir}/src/"
# Define the list of directories to copy
# dirs_to_copy=(chatglm2-6b-qlora-spdsvb-INSv9 chatglm-6b-pt-spdsvb-INSv9-128-5e-3-3000 chatglm2-6b-pt-spdsvb-INSv11-128-5e-3-3010 chatglm2-6b-qlora-INSv11-rank16-1e-3-30)
# rsync -av "${dirs_to_copy[@]/#/${src_dir}\/model\/ckpt\/}" "${dest_dir}/model/ckpt/"
# rsync -av $src_dir/model/moka-ai/ $dest_dir/model/moka-ai/
rsync -av $src_dir/deps $dest_dir/
cp $src_dir/Dockerfile $dest_dir/Dockerfile
cp $src_dir/Makefile $dest_dir/Makefile
cp $src_dir/.dockerignore $dest_dir/.dockerignore
sed -i 's/\/home\/zfh/\/home\/ssvb/g' $dest_dir/src/common/consts.py
# Create a tar archive of the destination directory
# tar -czvf $dest_dir.tar.gz $dest_dir
tar -czvf /home/zfh/deploy/aird.$(date +%F).tar.gz -C $dest_dir .
rm -f /home/zfh/deploy/aird_backup
ln -sf $dest_dir /home/zfh/deploy/aird_backup
\ No newline at end of file
MODEL_PATH_ChatGLM = "/home/zfh/models/chatglm-6b"
MODEL_PATH_ChatGLM2 = "/home/zfh/models/chatglm2-6b"
MODEL_PATH_ChatGLM2_32K = "/home/zfh/models/chatglm2-6b-32k"
MODEL_NAME_ChatGLM = "THUDM/chatglm-6b"
MODEL_NAME_ChatGLM2 = "THUDM/chatglm2-6b"
INSTRUCTION_V1="你是浦发硅谷银行网银系统的专家,请帮助解答用户在使用过程中遇到的问题。\n"
\ No newline at end of file
from typing import List
from langchain.schema import Document
from langchain.callbacks.base import (
BaseCallbackHandler,
BaseCallbackManager
)
class PrintRetrievalHandler(BaseCallbackHandler):
'''
Callback that records the similarity documents used for a query.
'''
def __init__(self) -> None:
super().__init__()
self.similarity: list = []
def on_retriever_start(self, query: str, **kwargs):
print(f"**Question:** {query}")
def on_retriever_end(self, documents, **kwargs):
self.similarity = [{"page_content":doc.page_content,"from_file":doc.metadata["filename"] or "","page_number":doc.metadata["page_number"] or 0} for doc in documents]
def getsimilarity(self)->List[Document]:
return self.similarity
\ No newline at end of file
from typing import List
from .prompts import ElementsPromptTemplate,ElementPromptTemplate
from langchain.base_language import BaseLanguageModel
from langchain.chains import LLMChain
class ElementsExtractor:
def __init__(self, llm: BaseLanguageModel):
self.prompt = ElementsPromptTemplate(input_variables=["knowledge","elements"])
self.chain = LLMChain(llm=llm, prompt=self.prompt,verbose=True)
self.prompt_foreach = ElementPromptTemplate(input_variables=["knowledge","element"])
self.chain_foreach = LLMChain(llm=llm, prompt=self.prompt_foreach,verbose=True)
def extract(self, knowledge: str,elements: List[str]) -> List[str]:
output = self.chain.run({"knowledge":knowledge,"elements":elements})
lines = [line for line in output.split("\n") if line.strip()]
return lines
def extract_foreach(self, knowledge: str,elements: List[str]) -> List[str]:
lines = []
for e in elements:
output = self.chain_foreach.run({"knowledge":knowledge,"element":e})
lines.append(output)
print(output)
return lines
\ No newline at end of file
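A usage sketch for ElementsExtractor, assuming src/ is on PYTHONPATH and an ERNIE access token is available; any langchain BaseLanguageModel wrapper from this repo could be swapped in.
import os
from llm.ernie import ErnieLLM, ModelType
from contract.extraction import ElementsExtractor

# The access token is read from the environment; the sample knowledge text is illustrative only.
llm = ErnieLLM(model_name=ModelType.ERNIE, access_token=os.environ["ERNIE_ACCESS_TOKEN"])
extractor = ElementsExtractor(llm=llm)
lines = extractor.extract(
    knowledge="甲方:某某公司;乙方:某某银行;合同金额:100万元。",
    elements=["甲方", "乙方", "合同金额"],
)
for line in lines:
    print(line)  # each line is expected to look like "要素:值"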
import os
from typing import Dict, Optional,List
from langchain.llms.base import BaseLLM,LLM
from langchain.callbacks.manager import CallbackManagerForLLMRun, Callbacks
import torch
from transformers import AutoTokenizer, AutoModel,AutoConfig,AutoModelForCausalLM
from transformers.generation.utils import GenerationConfig
from pydantic import root_validator
class BaichuanLLM(LLM):
model_name: str = "baichuan-inc/Baichuan-13B-Chat"
quantization_bit: Optional[int] = None
tokenizer: AutoTokenizer = None
model: AutoModel = None
def _llm_type(self) -> str:
return "chatglm_local"
@root_validator()
def validate_environment(cls, values: Dict) -> Dict:
if not values["model_name"]:
raise ValueError("No model name provided.")
model_name = values["model_name"]
tokenizer = AutoTokenizer.from_pretrained(model_name,use_fast=False,trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
model_name,
torch_dtype=torch.float16,
# device_map="auto",
trust_remote_code=True
)
model.generation_config = GenerationConfig.from_pretrained(
model_name
)
if values["quantization_bit"]:
print(f"Quantized to {values['quantization_bit']} bit")
model = model.quantize(values["quantization_bit"]).cuda()
else:
model=model.half().cuda()
model = model.eval()
values["tokenizer"] = tokenizer
values["model"] = model
return values
def _call(
self,
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
) -> str:
message = []
message.append({"role": "user", "content": prompt})
resp = self.model.chat(self.tokenizer,message)
# print(f"prompt:{prompt}\nresponse:{resp}\n")
return resp
\ No newline at end of file
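A usage sketch for BaichuanLLM, assuming the Baichuan-13B-Chat weights are reachable (hub id or local path) and a CUDA device is available:
from llm.baichuan import BaichuanLLM

# model_name may also point to a local checkpoint directory.
llm = BaichuanLLM(model_name="baichuan-inc/Baichuan-13B-Chat")
print(llm("用一句话介绍网上银行。"))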
import os
import requests
from typing import Dict, Optional,List,Any,Mapping,Iterator
from pydantic import root_validator
import torch
from transformers import AutoTokenizer, AutoModel,AutoConfig
import langchain
from langchain.llms.base import BaseLLM,LLM
from langchain.cache import InMemoryCache
from langchain.callbacks.manager import CallbackManagerForLLMRun, Callbacks, AsyncCallbackManagerForLLMRun
import aiohttp
import asyncio
# Enable the LLM cache
# langchain.llm_cache = InMemoryCache()
class ChatGLMLocLLM(LLM):
model_name: str = "THUDM/chatglm-6b"
ptuning_checkpoint: str = None
quantization_bit: Optional[int] = None
pre_seq_len: Optional[int] = None
prefix_projection: bool = False
tokenizer: AutoTokenizer = None
model: AutoModel = None
def _llm_type(self) -> str:
return "chatglm_local"
# @root_validator()
def validate_environment(cls, values: Dict) -> Dict:
if not values["model_name"]:
raise ValueError("No model name provided.")
model_name = values["model_name"]
tokenizer = AutoTokenizer.from_pretrained(model_name ,trust_remote_code=True)
config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
# model = AutoModel.from_pretrained(model_name, config=config, trust_remote_code=True)
if values["pre_seq_len"]:
config.pre_seq_len = values["pre_seq_len"]
if values["prefix_projection"]:
config.prefix_projection = values["prefix_projection"]
if values["ptuning_checkpoint"]:
ptuning_checkpoint = values["ptuning_checkpoint"]
print(f"Loading prefix_encoder weight from {ptuning_checkpoint}")
model = AutoModel.from_pretrained(model_name, config=config, trust_remote_code=True)
prefix_state_dict = torch.load(os.path.join(ptuning_checkpoint, "pytorch_model.bin"))
new_prefix_state_dict = {}
for k, v in prefix_state_dict.items():
if k.startswith("transformer.prefix_encoder."):
new_prefix_state_dict[k[len("transformer.prefix_encoder."):]] = v
model.transformer.prefix_encoder.load_state_dict(new_prefix_state_dict)
else:
model = AutoModel.from_pretrained(model_name, config=config, trust_remote_code=True).half().cuda()
if values["pre_seq_len"]:
# P-tuning v2
model = model.half().cuda()
model.transformer.prefix_encoder.float().cuda()
if values["quantization_bit"]:
print(f"Quantized to {values['quantization_bit']} bit")
model = model.quantize(values["quantization_bit"])
model = model.eval()
values["tokenizer"] = tokenizer
values["model"] = model
return values
def _call(
self,
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
) -> str:
resp,his = self.model.chat(self.tokenizer,prompt)
# print(f"prompt:{prompt}\nresponse:{resp}\n")
return resp
class ChatGLMSerLLM(LLM):
# URL of the model service
url: str = "http://127.0.0.1:8000"
chat_history: list = []
out_stream: bool = False
cache: bool = False
@property
def _llm_type(self) -> str:
return "chatglm3-6b"
def get_num_tokens(self, text: str) -> int:
resp = self._post(url=self.url+"/tokens",query=self._construct_query(text))
if resp.status_code == 200:
resp_json = resp.json()
predictions = resp_json['response']
# display(self.convert_data(resp_json['history']))
return predictions
else:
return len(text)
def convert_data(self,data):
result = []
for item in data:
result.append({'q': item[0], 'a': item[1]})
return result
def _construct_query(self, prompt: str,temperature = 0.95) -> Dict:
"""构造请求体
"""
# self.chat_history.append({"role": "user", "content": prompt})
query = {
"prompt": prompt,
"history":self.chat_history,
"max_length": 4096,
"top_p": 0.7,
"temperature": temperature
}
return query
@classmethod
def _post(self, url: str,
query: Dict) -> Any:
"""POST请求
"""
_headers = {"Content_Type": "application/json"}
with requests.session() as sess:
resp = sess.post(url,
json=query,
headers=_headers,
timeout=300)
return resp
async def _post_stream(self, url: str,
query: Dict,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,stream=False) -> Any:
"""POST请求
"""
_headers = {"Content_Type": "application/json"}
async with aiohttp.ClientSession() as sess:
async with sess.post(url, json=query,headers=_headers,timeout=300) as response:
if response.status == 200:
if stream and not run_manager:
print('not callable')
if run_manager:
for callable in run_manager.get_sync().handlers:
await callable.on_llm_start(None,None)
async for chunk in response.content.iter_any():
# Handle each streamed chunk
if chunk and run_manager:
for callable in run_manager.get_sync().handlers:
# print(chunk.decode("utf-8"),end="")
await callable.on_llm_new_token(chunk.decode("utf-8"))
if run_manager:
for callable in run_manager.get_sync().handlers:
await callable.on_llm_end(None)
else:
raise ValueError(f'ChatGLM request failed, http code: {response.status}')
def _call(self, prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
stream = False,
**kwargs: Any) -> str:
query = self._construct_query(prompt=prompt,temperature=kwargs["temperature"] if "temperature" in kwargs else 0.95)
# display("==============================")
# display(query)
# post
if stream or self.out_stream:
async def _post_stream():
await self._post_stream(url=self.url+"/stream",
query=query,run_manager=run_manager,stream=stream or self.out_stream)
asyncio.run(_post_stream())
return ''
else:
resp = self._post(url=self.url,
query=query)
if resp.status_code == 200:
resp_json = resp.json()
# self.chat_history.append({'q': prompt, 'a': resp_json['response']})
predictions = resp_json['response']
# display(self.convert_data(resp_json['history']))
return predictions
else:
raise ValueError(f'ChatGLM request failed, http code: {resp.status_code}')
async def _acall(
self,
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> str:
query = self._construct_query(prompt=prompt,temperature=kwargs["temperature"] if "temperature" in kwargs else 0.95)
await self._post_stream(url=self.url+"/stream",
query=query,run_manager=run_manager,stream=self.out_stream)
return ''
@property
def _identifying_params(self) -> Mapping[str, Any]:
"""Get the identifying parameters.
"""
_param_dict = {
"url": self.url
}
return _param_dict
\ No newline at end of file
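A usage sketch for ChatGLMSerLLM, assuming a ChatGLM HTTP service is already running at the given url (a placeholder for a local deployment) and exposes the endpoints used above:
from llm.chatglm import ChatGLMSerLLM

# The url is a placeholder for a locally deployed ChatGLM API server.
llm = ChatGLMSerLLM(url="http://127.0.0.1:8000")
print(llm("什么是网银盾?"))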
import os
import requests
from typing import Dict, Optional,List,Any,Mapping,Iterator
from pydantic import root_validator
import torch
from transformers import AutoTokenizer, AutoModel,AutoConfig
import langchain
from langchain.llms.base import BaseLLM,LLM
from langchain_openai import OpenAI
from langchain.cache import InMemoryCache
from langchain.callbacks.manager import CallbackManagerForLLMRun, Callbacks, AsyncCallbackManagerForLLMRun
class ChatGLMSerLLM(OpenAI):
def get_token_ids(self, text: str) -> List[int]:
if self.model_name.__contains__("chatglm"):
## Send an HTTP request to get the token count
url = f"{self.openai_api_base}/num_tokens"
query = {"prompt": text,"model": self.model_name}
_headers = {"Content_Type": "application/json","Authorization": "chatglm "+self.openai_api_key}
resp = self._post(url=url,query=query,headers= _headers)
if resp.status_code == 200:
resp_json = resp.json()
print(resp_json)
predictions = resp_json['choices'][0]['text']
## Convert the predictions string to an int
return [int(predictions)]
return [len(text)]
@classmethod
def _post(self, url: str,
query: Dict,headers: Dict) -> Any:
"""POST请求
"""
_headers = {"Content_Type": "application/json"}
_headers.update(headers)
with requests.session() as sess:
resp = sess.post(url,
json=query,
headers=_headers,
timeout=300)
return resp
\ No newline at end of file
import logging
import os
from typing import Any, Dict, List, Mapping, Optional
from langchain.llms.base import BaseLLM,LLM
from langchain.schema import LLMResult
from langchain.utils import get_from_dict_or_env
from langchain.callbacks.manager import CallbackManagerForLLMRun, Callbacks
from enum import Enum
from pydantic import root_validator, Field
from .ernie_sdk import CompletionRequest, ErnieBot, Message, bot_message, user_message
logger = logging.getLogger(__name__)
class ModelType(Enum):
ERNIE = "ernie"
ERNIE_LITE = "ernie-lite"
SHEETS1 = "sheets1"
SHEETS2 = "sheets2"
SHEET_COMB = "sheet-comb"
LLAMA2_7B = "llama2-7b"
LLAMA2_13B = "llama2-13b"
LLAMA2_70B = "llama2-70b"
QFCN_LLAMA2_7B = "qfcn-llama2-7b"
BLOOMZ_7B="bloomz-7b"
MODEL_SERVICE_BASE_URL = "https://aip.baidubce.com/rpc/2.0/"
MODEL_SERVICE_Suffix = {
ModelType.ERNIE: "ai_custom/v1/wenxinworkshop/chat/completions",
ModelType.ERNIE_LITE: "ai_custom/v1/wenxinworkshop/chat/eb-instant",
ModelType.SHEETS1: "ai_custom/v1/wenxinworkshop/chat/besheet",
ModelType.SHEETS2: "ai_custom/v1/wenxinworkshop/chat/besheets2",
ModelType.SHEET_COMB: "ai_custom/v1/wenxinworkshop/chat/sheet_comb1",
ModelType.LLAMA2_7B: "ai_custom/v1/wenxinworkshop/chat/llama_2_7b",
ModelType.LLAMA2_13B: "ai_custom/v1/wenxinworkshop/chat/llama_2_13b",
ModelType.LLAMA2_70B: "ai_custom/v1/wenxinworkshop/chat/llama_2_70b",
ModelType.QFCN_LLAMA2_7B: "ai_custom/v1/wenxinworkshop/chat/qianfan_chinese_llama_2_7b",
ModelType.BLOOMZ_7B: "ai_custom/v1/wenxinworkshop/chat/bloomz_7b1",
}
class ErnieLLM(LLM):
"""
ErnieLLM is a LLM that uses Ernie to generate text.
"""
model_name: Optional[ModelType] = None
access_token: Optional[str] = ""
@root_validator()
def validate_environment(cls, values: Dict) -> Dict:
"""Validate the environment."""
# print(values)
model_name = ModelType(get_from_dict_or_env(values, "model_name", "model_name", ModelType.ERNIE.value))
access_token = get_from_dict_or_env(values, "access_token", "ERNIE_ACCESS_TOKEN", "")
if not access_token:
raise ValueError("No access token provided.")
values["model_name"] = model_name
values["access_token"] = access_token
return values
def _call(
self,
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
) -> str:
request = CompletionRequest(messages=[Message("user",prompt)])
bot = ErnieBot(_get_model_service_url(self.model_name), self.access_token or "", request)
try:
# Request a completion from the Ernie service
response = bot.get_response().result
# print("response: ",response)
return response
except Exception as e:
# Handle the exception
print("exception:",e)
return e.__str__()
@property
def _llm_type(self) -> str:
"""Return type of llm."""
return "ernie"
# def _identifying_params(self) -> Mapping[str, Any]:
# return {
# "name": "ernie",
# }
def _get_model_service_url(model_name) -> str:
# print("_get_model_service_url model_name: ",model_name)
return MODEL_SERVICE_BASE_URL+MODEL_SERVICE_Suffix[model_name]
class ErnieChat(LLM):
model_name: ModelType
access_token: str
prefix_messages: List = Field(default_factory=list)
id: str = ""
def _call(
self,
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
) -> str:
msg = user_message(prompt)
request = CompletionRequest(messages=self.prefix_messages+[msg])
bot = ErnieBot(_get_model_service_url(self.model_name),self.access_token,request)
try:
# Request a completion from the Ernie service
response = bot.get_response().result
if self.id == "":
self.id = bot.get_response().id
self.prefix_messages.append(msg)
self.prefix_messages.append(bot_message(response))
return response
except Exception as e:
# Handle the exception
raise e
def _get_id(self) -> str:
return self.id
@property
def _llm_type(self) -> str:
"""Return type of llm."""
return "ernie"
\ No newline at end of file
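A usage sketch for ErnieLLM, assuming a valid access token is exported as ERNIE_ACCESS_TOKEN:
import os
from llm.ernie import ErnieLLM, ModelType

llm = ErnieLLM(model_name=ModelType.ERNIE, access_token=os.environ["ERNIE_ACCESS_TOKEN"])
print(llm("介绍一下网上银行的安全措施。"))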
from dataclasses import asdict, dataclass
from typing import List
from pydantic import BaseModel, Field
from enum import Enum
class MessageRole(str, Enum):
USER = "user"
BOT = "assistant"
@dataclass
class Message:
role: str
content: str
@dataclass
class CompletionRequest:
messages: List[Message]
stream: bool = False
user: str = ""
@dataclass
class Usage:
prompt_tokens: int
completion_tokens: int
total_tokens: int
@dataclass
class CompletionResponse:
id: str
object: str
created: int
result: str
need_clear_history: bool
ban_round: int = 0
sentence_id: int = 0
is_end: bool = False
usage: Usage = None
is_safe: bool = False
is_truncated: bool = False
class ErrorResponse(BaseModel):
error_code: int = Field(...)
error_msg: str = Field(...)
id: str = Field(...)
class ErnieBot():
url: str
access_token: str
request: CompletionRequest
def __init__(self, url: str, access_token: str, request: CompletionRequest):
self.url = url
self.access_token = access_token
self.request = request
def get_response(self) -> CompletionResponse:
import requests
import json
headers = {'Content-Type': 'application/json'}
params = {'access_token': self.access_token}
request_dict = asdict(self.request)
response = requests.post(self.url, params=params,data=json.dumps(request_dict), headers=headers)
# print(response.json())
try:
return CompletionResponse(**response.json())
except Exception as e:
print(e)
raise Exception(response.json())
def user_message(prompt: str) -> Message:
return Message(MessageRole.USER, prompt)
def bot_message(prompt: str) -> Message:
return Message(MessageRole.BOT, prompt)
\ No newline at end of file
import os
import requests
from typing import Dict, Optional,List,Any,Mapping,Iterator
from pydantic import root_validator
from langchain.llms.base import BaseLLM,LLM
from langchain.cache import InMemoryCache
from langchain.callbacks.manager import CallbackManagerForLLMRun, Callbacks, AsyncCallbackManagerForLLMRun
import qianfan
from qianfan import ChatCompletion
# Enable the LLM cache
# langchain.llm_cache = InMemoryCache()
class ChatERNIESerLLM(LLM):
# URL of the model service
chat_completion:ChatCompletion = None
# url: str = "http://127.0.0.1:8000"
chat_history: list = []
out_stream: bool = False
cache: bool = False
model_name:str = "ERNIE-Bot"
# def __init__(self):
# self.chat_completion = qianfan.ChatCompletion(ak="pT7sV1smp4AeDl0LjyZuHBV9", sk="b3N0ibo1IKTLZlSs7weZc8jdR0oHjyMu")
@property
def _llm_type(self) -> str:
return self.model_name
def get_num_tokens(self, text: str) -> int:
return len(text)
def convert_data(self,data):
result = []
for item in data:
result.append({'q': item[0], 'a': item[1]})
return result
def _call(self, prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
stream = False,
**kwargs: Any) -> str:
resp = self.chat_completion.do(model=self.model_name,messages=[{
"role": "user",
"content": prompt
}])
print(resp)
assert resp.code == 200
return resp.body["result"]
async def _post_stream(self,
query: Dict,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
stream=False) -> Any:
"""POST请求
"""
async for r in await self.chat_completion.ado(model=self.model_name,messages=[query], stream=stream):
assert r.code == 200
if run_manager:
for callable in run_manager.get_sync().handlers:
await callable.on_llm_new_token(r.body["result"])
async def _acall(
self,
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> str:
await self._post_stream(query={
"role": "user",
"content": prompt
},stream=True,run_manager=run_manager)
return ''
\ No newline at end of file
import os
import transformers
import torch
from transformers import AutoModel, AutoTokenizer, AutoConfig, DataCollatorForSeq2Seq
from peft import PeftModel
class ModelLoader:
def __init__(self, model_name_or_path, pre_seq_len=0, prefix_projection=False):
self.config = AutoConfig.from_pretrained(model_name_or_path, trust_remote_code=True)
if pre_seq_len is not None and pre_seq_len > 0:
self.config.pre_seq_len = pre_seq_len
self.config.prefix_projection = prefix_projection
self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, trust_remote_code=True)
self.model = AutoModel.from_pretrained(model_name_or_path, config=self.config, trust_remote_code=True).half()
# self.model = self.model.cuda()
self.base_model = self.model
def quantize(self, quantization_bit):
if quantization_bit is not None:
print(f"Quantized to {quantization_bit} bit")
self.model = self.model.quantize(quantization_bit)
return self.model
def models(self):
return self.model, self.tokenizer
def collator(self):
return DataCollatorForSeq2Seq(tokenizer=self.tokenizer, model=self.model)
def load_lora(self,ckpt_path,name="default"):
# Saves GPU memory during training
peft_loaded = PeftModel.from_pretrained(self.base_model,ckpt_path,adapter_name=name)
self.model = peft_loaded.merge_and_unload()
print(f"Load LoRA model successfully!")
def load_loras(self,ckpt_paths,name="default"):
if len(ckpt_paths)==0:
return
first = True
for name, path in ckpt_paths.items():
print(f"Load {name} from {path}")
if first:
peft_loaded = PeftModel.from_pretrained(self.base_model, path, adapter_name=name)
first = False
else:
peft_loaded.load_adapter(path,adapter_name=name)
peft_loaded.set_adapter(name)
self.model = peft_loaded
def load_prefix(self,ckpt_path):
prefix_state_dict = torch.load(os.path.join(ckpt_path, "pytorch_model.bin"))
new_prefix_state_dict = {}
for k, v in prefix_state_dict.items():
if k.startswith("transformer.prefix_encoder."):
new_prefix_state_dict[k[len("transformer.prefix_encoder."):]] = v
self.model.transformer.prefix_encoder.load_state_dict(new_prefix_state_dict)
self.model.transformer.prefix_encoder.float()
print(f"Load prefix model successfully!")
import logging
import os
from typing import Any, Dict, List, Mapping, Optional
from langchain.llms.base import BaseLLM,LLM
from langchain.schema import LLMResult
from langchain.utils import get_from_dict_or_env
from langchain.callbacks.manager import CallbackManagerForLLMRun, Callbacks
from enum import Enum
from pydantic import root_validator, Field
from .xinghuo import SparkApi
from .xinghuo.ws import SparkAPI
logger = logging.getLogger(__name__)
text =[]
# length = 0
def getText(role,content):
jsoncon = {}
jsoncon["role"] = role
jsoncon["content"] = content
text.append(jsoncon)
return text
def getlength(text):
length = 0
for content in text:
temp = content["content"]
leng = len(temp)
length += leng
return length
def checklen(text):
while (getlength(text) > 8000):
del text[0]
return text
class SparkLLM(LLM):
"""
SparkLLM is an LLM wrapper that calls the iFLYTEK Spark (星火) API to generate text.
"""
appid: str = Field(
None,
description="APPID",
)
api_key: str = Field(
None,
description="API_KEY",
)
api_secret: str = Field(
None,
description="API_SECRET",
)
version: str = Field(
None,
description="version",
)
api: SparkAPI = Field(
None,
description="api",
)
@root_validator()
def validate_environment(cls, values: Dict) -> Dict:
"""Validate the environment."""
# print(values)
appid = get_from_dict_or_env(values, "appid", "XH_APPID", "")
api_key = get_from_dict_or_env(values, "api_key", "XH_API_KEY", "")
api_secret = get_from_dict_or_env(values, "api_secret", "XH_API_SECRET", "")
version = values.get("version", "v1")
if not appid:
raise ValueError("No appid provided.")
if not api_key:
raise ValueError("No api_key provided.")
if not api_secret:
raise ValueError("No api_secret provided.")
values["appid"] = appid
values["api_key"] = api_key
values["api_secret"] = api_secret
api=SparkAPI(appid,api_key,api_secret,version)
values["api"]=api
return values
def _call(
self,
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
) -> str:
question = self.getText("user",prompt)
try:
# Call the Spark API
# SparkApi.main(self.appid,self.api_key,self.api_secret,self.Spark_url,self.domain,question)
self.api.call(question)
response = self.api.answer
return response
except Exception as e:
# Handle the exception
print("exception:",e)
raise e
def getText(self,role,content):
text = []
jsoncon = {}
jsoncon["role"] = role
jsoncon["content"] = content
text.append(jsoncon)
return text
@property
def _llm_type(self) -> str:
"""Return type of llm."""
return "xinghuo"
\ No newline at end of file
from langchain.llms.base import LLM
from langchain.callbacks.manager import CallbackManagerForLLMRun
from pydantic import root_validator
from typing import Dict, List, Optional
from transformers import PreTrainedModel, PreTrainedTokenizer
class WrapperLLM(LLM):
tokenizer: PreTrainedTokenizer = None
model: PreTrainedModel = None
@root_validator()
def validate_environment(cls, values: Dict) -> Dict:
"""Validate the environment."""
# print(values)
if values.get("model") is None:
raise ValueError("No model provided.")
if values.get("tokenizer") is None:
raise ValueError("No tokenizer provided.")
return values
def _call(
self,
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
) -> str:
resp,his = self.model.chat(self.tokenizer,prompt)
return resp
@property
def _llm_type(self) -> str:
"""Return type of llm."""
return "wrapper"
\ No newline at end of file
import _thread as thread
import base64
import datetime
import hashlib
import hmac
import json
from urllib.parse import urlparse
import ssl
from datetime import datetime
from time import mktime
from urllib.parse import urlencode
from wsgiref.handlers import format_date_time
import websocket # uses the websocket-client package
answer = ""
class Ws_Param(object):
# Initialization
def __init__(self, APPID, APIKey, APISecret, Spark_url):
self.APPID = APPID
self.APIKey = APIKey
self.APISecret = APISecret
self.host = urlparse(Spark_url).netloc
self.path = urlparse(Spark_url).path
self.Spark_url = Spark_url
# Build the request URL
def create_url(self):
# Generate an RFC 1123 formatted timestamp
now = datetime.now()
date = format_date_time(mktime(now.timetuple()))
# Assemble the string to be signed
signature_origin = "host: " + self.host + "\n"
signature_origin += "date: " + date + "\n"
signature_origin += "GET " + self.path + " HTTP/1.1"
# Sign it with HMAC-SHA256
signature_sha = hmac.new(self.APISecret.encode('utf-8'), signature_origin.encode('utf-8'),
digestmod=hashlib.sha256).digest()
signature_sha_base64 = base64.b64encode(signature_sha).decode(encoding='utf-8')
authorization_origin = f'api_key="{self.APIKey}", algorithm="hmac-sha256", headers="host date request-line", signature="{signature_sha_base64}"'
authorization = base64.b64encode(authorization_origin.encode('utf-8')).decode(encoding='utf-8')
# Collect the authentication parameters into a dict
v = {
"authorization": authorization,
"date": date,
"host": self.host
}
# Append the authentication parameters to build the URL
url = self.Spark_url + '?' + urlencode(v)
# The connection URL can be printed here and compared, with identical parameters, against the URL generated by your own code
return url
# Handle websocket errors
def on_error(ws, error):
print("### error:", error)
# Handle websocket close events
def on_close(ws,one,two):
print(" ")
# Handle the websocket connection being established
def on_open(ws):
thread.start_new_thread(run, (ws,))
def run(ws, *args):
data = json.dumps(gen_params(appid=ws.appid, domain= ws.domain,question=ws.question))
ws.send(data)
# Handle incoming websocket messages
def on_message(ws, message):
# print(message)
data = json.loads(message)
code = data['header']['code']
if code != 0:
print(f'Request error: {code}, {data}')
ws.close()
else:
choices = data["payload"]["choices"]
status = choices["status"]
content = choices["text"][0]["content"]
print(content,end ="")
global answer
answer += content
# print(1)
if status == 2:
ws.close()
def gen_params(appid, domain,question):
"""
Build the request parameters from the appid and the user's question.
"""
data = {
"header": {
"app_id": appid,
"uid": "1234"
},
"parameter": {
"chat": {
"domain": domain,
"random_threshold": 0.5,
"max_tokens": 2048,
"auditing": "default"
}
},
"payload": {
"message": {
"text": question
}
}
}
return data
def main(appid, api_key, api_secret, Spark_url,domain, question):
# print("星火:")
global answer
answer = ""
wsParam = Ws_Param(appid, api_key, api_secret, Spark_url)
websocket.enableTrace(False)
wsUrl = wsParam.create_url()
ws = websocket.WebSocketApp(wsUrl, on_message=on_message, on_error=on_error, on_close=on_close, on_open=on_open)
ws.appid = appid
ws.question = question
ws.domain = domain
ws.run_forever(sslopt={"cert_reqs": ssl.CERT_NONE})
import SparkApi
#The following credentials are obtained from the console
appid = "XXXXXXXX" #fill in the APPID obtained from the console
api_secret = "XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX" #fill in the APISecret obtained from the console
api_key ="XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX" #fill in the APIKey obtained from the console
#Used to configure the model version, default "general/generalv2"
domain = "general" # v1.5
# domain = "generalv2" # v2.0
#Service URL of the cloud environment
Spark_url = "ws://spark-api.xf-yun.com/v1.1/chat" # v1.5 endpoint
# Spark_url = "ws://spark-api.xf-yun.com/v2.1/chat" # v2.0 endpoint
text =[]
# length = 0
def getText(role,content):
jsoncon = {}
jsoncon["role"] = role
jsoncon["content"] = content
text.append(jsoncon)
return text
def getlength(text):
length = 0
for content in text:
temp = content["content"]
leng = len(temp)
length += leng
return length
def checklen(text):
while (getlength(text) > 8000):
del text[0]
return text
if __name__ == '__main__':
text.clear()
while(1):
Input = input("\n" +"我:")
question = checklen(getText("user",Input))
SparkApi.answer =""
print("星火:",end = "")
SparkApi.main(appid,api_key,api_secret,Spark_url,domain,question)
getText("assistant",SparkApi.answer)
# print(str(text))
import _thread as thread
import base64
import datetime
import hashlib
import hmac
import json
import ssl
from datetime import datetime
from time import mktime
from urllib.parse import urlparse, urlencode
from wsgiref.handlers import format_date_time
import websocket # uses the websocket-client package
URL_V1_5="ws://spark-api.xf-yun.com/v1.1/chat"
URL_V2="ws://spark-api.xf-yun.com/v2.1/chat"
Domain_V1_5="general"
Domain_V2="generalv2"
class SparkAPI:
def __init__(self, APPID, APIKey, APISecret, Version="v1"):
self.APPID = APPID
self.APIKey = APIKey
self.APISecret = APISecret
if Version == "v1":
self.Spark_url = URL_V1_5
self.domain = Domain_V1_5
elif Version == "v2":
self.Spark_url = URL_V2
self.domain = Domain_V2
self.host = urlparse(self.Spark_url).netloc
self.path = urlparse(self.Spark_url).path
self.answer = ""
def create_url(self):
now = datetime.now()
date = format_date_time(mktime(now.timetuple()))
signature_origin = "host: " + self.host + "\n"
signature_origin += "date: " + date + "\n"
signature_origin += "GET " + self.path + " HTTP/1.1"
signature_sha = hmac.new(self.APISecret.encode('utf-8'), signature_origin.encode('utf-8'),
digestmod=hashlib.sha256).digest()
signature_sha_base64 = base64.b64encode(signature_sha).decode(encoding='utf-8')
authorization_origin = f'api_key="{self.APIKey}", algorithm="hmac-sha256", headers="host date request-line", signature="{signature_sha_base64}"'
authorization = base64.b64encode(authorization_origin.encode('utf-8')).decode(encoding='utf-8')
v = {
"authorization": authorization,
"date": date,
"host": self.host
}
url = self.Spark_url + '?' + urlencode(v)
return url
def on_error(self, ws, error):
print("### error:", error)
def on_close(self, ws, one, two):
print(" ")
def on_open(self, ws):
thread.start_new_thread(self.run, (ws,))
def run(self, ws, *args):
data = json.dumps(self.gen_params(appid=self.APPID, domain=self.domain, question=ws.question))
ws.send(data)
def on_message(self, ws, message):
data = json.loads(message)
code = data['header']['code']
if code != 0:
print(f'Request error: {code}, {data}')
ws.close()
else:
choices = data["payload"]["choices"]
status = choices["status"]
content = choices["text"][0]["content"]
# print(content, end="")
self.answer += content
if status == 2:
ws.close()
def gen_params(self, appid, domain, question):
data = {
"header": {
"app_id": appid,
"uid": "1234"
},
"parameter": {
"chat": {
"domain": domain,
"random_threshold": 0.5,
"max_tokens": 2048,
"auditing": "default"
}
},
"payload": {
"message": {
"text": question
}
}
}
return data
def call(self, question):
self.answer = ""
wsUrl = self.create_url()
websocket.enableTrace(False)
ws = websocket.WebSocketApp(wsUrl, on_message=self.on_message, on_error=self.on_error, on_close=self.on_close, on_open=self.on_open)
ws.question = question
ws.run_forever(sslopt={"cert_reqs": ssl.CERT_NONE})
\ No newline at end of file
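A usage sketch for SparkAPI, assuming valid iFLYTEK credentials and the package layout implied by the `from .xinghuo.ws import SparkAPI` import in the wrapper above:
from llm.xinghuo.ws import SparkAPI

# Credentials are placeholders obtained from the iFLYTEK console.
api = SparkAPI("your-appid", "your-api-key", "your-api-secret", Version="v1")
api.call([{"role": "user", "content": "你好"}])
print(api.answer)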
from abc import ABC, abstractmethod
class BaseCallback(ABC):
@abstractmethod
def filter(self,title:str,content:str) -> bool: # return True to discard the current paragraph
pass
\ No newline at end of file
from langchain.text_splitter import CharacterTextSplitter
import re
from typing import List
from .config import SENTENCE_SIZE
class ChineseTextSplitter(CharacterTextSplitter):
def __init__(self, pdf: bool = False, sentence_size: int = SENTENCE_SIZE, **kwargs):
super().__init__(**kwargs)
self.pdf = pdf
self.sentence_size = sentence_size
def split_text1(self, text: str) -> List[str]:
if self.pdf:
text = re.sub(r"\n{3,}", "\n", text)
text = re.sub('\s', ' ', text)
text = text.replace("\n\n", "")
sent_sep_pattern = re.compile('([﹒﹔﹖﹗.。!?]["’”」』]{0,2}|(?=["‘“「『]{1,2}|$))') # del :;
sent_list = []
for ele in sent_sep_pattern.split(text):
if sent_sep_pattern.match(ele) and sent_list:
sent_list[-1] += ele
elif ele:
sent_list.append(ele)
return sent_list
def split_text(self, text: str) -> List[str]: ## the logic here could be refined further
if self.pdf:
text = re.sub(r"\n{3,}", r"\n", text)
text = re.sub('\s', " ", text)
text = re.sub("\n\n", "", text)
text = re.sub(r'([;;.!?。!?\?])([^”’])', r"\1\n\2", text) # single-character sentence terminators
text = re.sub(r'(\.{6})([^"’”」』])', r"\1\n\2", text) # English ellipsis
text = re.sub(r'(\…{2})([^"’”」』])', r"\1\n\2", text) # Chinese ellipsis
text = re.sub(r'([;;!?。!?\?]["’”」』]{0,2})([^;;!?,。!?\?])', r'\1\n\2', text)
# A closing quote only ends a sentence when a terminator precedes it, so the newline goes after the quote; the rules above deliberately keep the quotes.
text = text.rstrip() # drop any extra trailing \n at the end of the paragraph
# Many rule sets also handle the semicolon, but it is ignored here, as are dashes and English double quotes; adjust as needed.
ls = [i for i in text.split("\n") if i]
for ele in ls:
if len(ele) > self.sentence_size:
ele1 = re.sub(r'([,,.]["’”」』]{0,2})([^,,.])', r'\1\n\2', ele)
ele1_ls = ele1.split("\n")
for ele_ele1 in ele1_ls:
if len(ele_ele1) > self.sentence_size:
ele_ele2 = re.sub(r'([\n]{1,}| {2,}["’”」』]{0,2})([^\s])', r'\1\n\2', ele_ele1)
ele2_ls = ele_ele2.split("\n")
for ele_ele2 in ele2_ls:
if len(ele_ele2) > self.sentence_size:
ele_ele3 = re.sub('( ["’”」』]{0,2})([^ ])', r'\1\n\2', ele_ele2)
ele2_id = ele2_ls.index(ele_ele2)
ele2_ls = ele2_ls[:ele2_id] + [i for i in ele_ele3.split("\n") if i] + ele2_ls[
ele2_id + 1:]
ele_id = ele1_ls.index(ele_ele1)
ele1_ls = ele1_ls[:ele_id] + [i for i in ele2_ls if i] + ele1_ls[ele_id + 1:]
id = ls.index(ele)
ls = ls[:id] + [i for i in ele1_ls if i] + ls[id + 1:]
return ls
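A usage sketch for ChineseTextSplitter; the module path is hypothetical and depends on where the loader package sits on PYTHONPATH.
from loader.chinese_text_splitter import ChineseTextSplitter  # hypothetical module path

splitter = ChineseTextSplitter(pdf=False, sentence_size=100)
chunks = splitter.split_text("第一句话。第二句话!这是第三句话,比较长一些?")
for chunk in chunks:
    print(chunk)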
# Sentence chunk length used when splitting text
SENTENCE_SIZE = 100
ZH_TITLE_ENHANCE = False
\ No newline at end of file
from typing import List
from langchain.docstore.document import Document
import re
def under_non_alpha_ratio(text: str, threshold: float = 0.5):
"""Checks if the proportion of non-alpha characters in the text snippet exceeds a given
threshold. This helps prevent text like "-----------BREAK---------" from being tagged
as a title or narrative text. The ratio does not count spaces.
Parameters
----------
text
The input string to test
threshold
If the proportion of non-alpha characters exceeds this threshold, the function
returns False
"""
if len(text) == 0:
return False
alpha_count = len([char for char in text if char.strip() and char.isalpha()])
total_count = len([char for char in text if char.strip()])
try:
ratio = alpha_count / total_count
return ratio < threshold
except:
return False
def is_possible_title(
text: str,
title_max_word_length: int = 20,
non_alpha_threshold: float = 0.5,
) -> bool:
"""Checks to see if the text passes all of the checks for a valid title.
Parameters
----------
text
The input text to check
title_max_word_length
The maximum number of words a title can contain
non_alpha_threshold
The minimum number of alpha characters the text needs to be considered a title
"""
# Text of length 0 is definitely not a title
if len(text) == 0:
print("Not a title. Text is empty.")
return False
# Text that ends in punctuation is not a title
ENDS_IN_PUNCT_PATTERN = r"[^\w\s]\Z"
ENDS_IN_PUNCT_RE = re.compile(ENDS_IN_PUNCT_PATTERN)
if ENDS_IN_PUNCT_RE.search(text) is not None:
return False
# The text must not exceed the configured length limit (default 20)
# NOTE(robinson) - splitting on spaces here instead of word tokenizing because it
# is less expensive and actual tokenization doesn't add much value for the length check
if len(text) > title_max_word_length:
return False
# The proportion of non-alphabetic characters must not be too high, otherwise it is not a title
if under_non_alpha_ratio(text, threshold=non_alpha_threshold):
return False
# NOTE(robinson) - Prevent flagging salutations like "To My Dearest Friends," as titles
if text.endswith((",", ".", ",", "。")):
return False
if text.isnumeric():
print(f"Not a title. Text is all numeric:\n\n{text}") # type: ignore
return False
# The leading characters (first 5 by default) should contain a digit
if len(text) < 5:
text_5 = text
else:
text_5 = text[:5]
alpha_in_text_5 = sum(list(map(lambda x: x.isnumeric(), list(text_5))))
if not alpha_in_text_5:
return False
return True
def zh_title_enhance(docs: List[Document]) -> List[Document] | None:
title = None
if len(docs) > 0:
for doc in docs:
if is_possible_title(doc.page_content):
doc.metadata['category'] = 'cn_Title'
title = doc.page_content
elif title:
doc.page_content = f"下文与({title})有关。{doc.page_content}"
return docs
else:
print("文件不存在")
from typing import List
from .prompts import QuestionGeneratePromptTemplate,AnswerPromptTemplate
from langchain.base_language import BaseLanguageModel
from langchain.chains import LLMChain
from langchain.prompts import StringPromptTemplate,PromptTemplate
class QAGenerator:
def __init__(self, llm: BaseLanguageModel):
self.prompt = QuestionGeneratePromptTemplate(input_variables=["knowledge","question_number"])
self.chain = LLMChain(llm=llm, prompt=self.prompt)
self.answer_prompt = AnswerPromptTemplate(input_variables=["knowledge","question"])
self.answer_chain = LLMChain(llm=llm, prompt=self.answer_prompt)
def generate_questions(self, knowledge: str,question_number: int | None=None) -> List[str]:
output = self.chain.run({"knowledge":knowledge,"question_number":question_number})
lines = [line for line in output.split("\n") if line.strip() and line.startswith("问:")]
return lines
def generate_answer(self, knowledge: str, question: str) -> str:
answer = self.answer_chain.run({"knowledge":knowledge,"question":question})
return answer
def generate(self, knowledge: str, question_number=3):
questions = self.generate_questions(knowledge, question_number)
if len(questions) == 0:
return None
questions = [question.replace("问:","") for question in questions if question.startswith("问:")]
answers = []
for question in questions:
answer = self.generate_answer(knowledge, question)
answers.append(answer)
return [(question,answer) for question,answer in zip(questions,answers)]
prompt_template="""从下面这段话中提取出关键信息,生成 {question_number} 个问题,并回答相关问题。
输出格式为一问一答,问题以"问:"开头,答案以"答:"开头。
不需要标注问题序号, 问题和答案相互映射,每个问题相互独立。
{knowledge}
"""
class TrainData:
def __init__(self, llm: BaseLanguageModel):
self.prompt = PromptTemplate.from_template(prompt_template)
self.chain = LLMChain(llm=llm, prompt=self.prompt,verbose=False)
def generate(self, knowledge: str, question_number=3):
res=self.chain.run({"knowledge":knowledge,"question_number":question_number})
print(res)
qas = []
q, a = None, ''
lines = res.split("\n")
for line in lines:
line = line.strip()
if line.startswith('问:'):
if q is not None and a is not None:
qas.append((q, a))
q = line.replace('问:', '')
a = ''
elif line.startswith('答:'):
a += line.replace('答:', '') + '\n'
elif a is not None:
a += line + '\n'
if q is not None and a is not None:
qas.append((q, a))
return qas
# qas = [qa for qa in qas if qa.strip()]
# questions = [qa for qa in qas if qa.startswith("问:")]
# answers = [qa for qa in qas if qa.startswith("答:")]
# if len(questions) != len(answers):
# print(ValueError("questions and answers are not equal"))
# return []
# return [(question.strip().replace("问:",""),answer.strip().replace("答:","")) for question,answer in zip(questions,answers)]
# return [{"Q":question,"A":answer} for question,answer in zip(questions,answers)]
\ No newline at end of file
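A usage sketch for TrainData, assuming a running ChatGLM service behind ChatGLMSerLLM and a hypothetical module path for the generator classes above:
from llm.chatglm import ChatGLMSerLLM
from qa.generator import TrainData  # hypothetical module path

llm = ChatGLMSerLLM(url="http://127.0.0.1:8000")  # placeholder service url
gen = TrainData(llm=llm)
qas = gen.generate("网银盾是用于保障网上银行交易安全的USB密钥,首次使用前需要下载证书。", question_number=2)
for q, a in qas:
    print("Q:", q)
    print("A:", a)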
from typing import Any
from langchain.prompts import StringPromptTemplate,PromptTemplate
from pydantic import BaseModel, validator
template="""
{knowledge}
请从上述内容中归纳总结出 {question_number} 个问题,问题以"问:"开头,以问号结尾,用空行分隔。
"""
prompt=PromptTemplate.from_template(template)
class QuestionGeneratePromptTemplate(StringPromptTemplate, BaseModel):
def format(self, **kwargs) -> str:
if "question_number" not in kwargs or kwargs["question_number"] is None:
question_number = 5
else:
question_number = kwargs["question_number"]
if "knowledge" not in kwargs:
raise ValueError("knowledge is required")
knowledge = kwargs["knowledge"]
return prompt.format(question_number=question_number, knowledge=knowledge)
def _prompt_type(self):
return "question-generator"
template_answer="""
请参考以下内容中回答问题:{question},回答以"答:"开头。
下面是你参考的内容:
{knowledge}
"""
class AnswerPromptTemplate(StringPromptTemplate, BaseModel):
def format(self, **kwargs: Any) -> str:
if "knowledge" not in kwargs:
raise ValueError("knowledge is required")
if "question" not in kwargs:
raise ValueError("question is required")
return template_answer.format(question=kwargs["question"], knowledge=kwargs["knowledge"])
\ No newline at end of file
from langchain import PromptTemplate
from langchain.base_language import BaseLanguageModel
from langchain.chains import LLMChain
question_rdf_prompt_template="""
你的任务是帮助我重新整理问题,使得问题中的指代信息更加明确,你的输出只会是原问题或者重新整理之后的问题,不包含其他内容。
==============================
{history}
以上是对话历史。
==============================
这是我的问题:
{question}
==============================
如果你认为这个问题中存在不明确的指代信息,请从对话历史中提取出相关信息,重新定义问题并直接输出,不要增加无关内容。
否则,请直接输出原问题。直接输出原问题的格式为:
{question}
"""
question_rdf_prompt=PromptTemplate.from_template(question_rdf_prompt_template)
class QuestionRDF:
def __init__(self,llm:BaseLanguageModel):
self.llm=llm
self.chain=LLMChain(llm=llm,prompt=question_rdf_prompt,verbose=True)
def generate(self,history:str,question:str):
res=self.chain.run({"history":history,"question":question})
return res
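A usage sketch for QuestionRDF, again with a placeholder service url and a hypothetical module path for the class above:
from llm.chatglm import ChatGLMSerLLM
from chat.question_rdf import QuestionRDF  # hypothetical module path

llm = ChatGLMSerLLM(url="http://127.0.0.1:8000")
rdf = QuestionRDF(llm=llm)
history = "问:网银盾怎么申请?\n答:请携带有效证件到柜台办理。"
print(rdf.generate(history=history, question="它的收费标准是什么?"))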
accelerate==0.21.0
aiofiles==23.1.0
aiohttp==3.8.5
aiosignal==1.3.1
altair==5.0.1
annotated-types==0.5.0
anyio==3.7.1
async-timeout==4.0.2
attrs==23.1.0
bitsandbytes==0.41.1
blinker==1.6.2
certifi==2023.7.22
cffi==1.15.1
chardet==5.1.0
charset-normalizer==3.2.0
click==8.1.6
cmake==3.27.0
contourpy==1.1.0
cpm-kernels==1.0.11
cryptography==41.0.2
cycler==0.11.0
dataclasses-json==0.5.13
datasets==2.14.0
dill==0.3.7
et-xmlfile==1.1.0
faiss==1.7.4
fastapi==0.100.0
ffmpy==0.3.1
filelock==3.12.2
filetype==1.2.0
Flask==2.3.2
fonttools==4.41.1
frozenlist==1.4.0
fsspec==2023.6.0
gevent==23.9.0.post1
gradio==3.38.0
gradio_client==0.2.10
greenlet==2.0.2
h11==0.14.0
httpcore==0.17.3
httpx==0.24.1
huggingface-hub==0.16.4
idna==3.4
itsdangerous==2.1.2
jieba==0.42.1
Jinja2==3.1.2
joblib==1.3.1
jsonschema==4.18.4
jsonschema-specifications==2023.7.1
kiwisolver==1.4.4
langchain==0.0.242
langsmith==0.0.14
latex2mathml==3.76.0
linkify-it-py==2.0.2
lit==16.0.6
loguru==0.7.0
lxml==4.9.3
Markdown==3.4.4
markdown-it-py==2.2.0
MarkupSafe==2.1.3
marshmallow==3.20.1
matplotlib==3.7.2
mdit-py-plugins==0.3.3
mdtex2html==1.2.0
mdurl==0.1.2
mpmath==1.3.0
msg-parser==1.2.0
multidict==6.0.4
multiprocess==0.70.15
mypy-extensions==1.0.0
networkx==3.1
nltk==3.8.1
numexpr==2.8.4
numpy @ file:///home/conda/feedstock_root/build_artifacts/numpy_1691056235090/work
nvidia-cublas-cu11==11.10.3.66
nvidia-cuda-cupti-cu11==11.7.101
nvidia-cuda-nvrtc-cu11==11.7.99
nvidia-cuda-runtime-cu11==11.7.99
nvidia-cudnn-cu11==8.5.0.96
nvidia-cufft-cu11==10.9.0.58
nvidia-curand-cu11==10.2.10.91
nvidia-cusolver-cu11==11.4.0.1
nvidia-cusparse-cu11==11.7.4.91
nvidia-nccl-cu11==2.14.3
nvidia-nvtx-cu11==11.7.91
olefile==0.46
openai==0.28.0
openapi-schema-pydantic==1.2.4
openpyxl==3.1.2
orjson==3.9.2
packaging==23.1
pandas==2.0.3
pdf2image==1.16.3
pdfminer.six==20221105
peft==0.4.0
Pillow==10.0.0
pipdeptree==2.13.0
psutil==5.9.5
psycopg2==2.9.7
pyarrow==12.0.1
pycparser==2.21
pydantic==1.10.12
pydantic_core==2.4.0
pydub==0.25.1
pypandoc==1.11
pyparsing==3.0.9
python-dateutil==2.8.2
python-docx==0.8.11
python-dotenv==1.0.0
python-magic==0.4.27
python-multipart==0.0.6
python-pptx==0.6.21
pytz==2023.3
PyYAML==6.0.1
referencing==0.30.0
regex==2023.6.3
requests==2.31.0
rouge-chinese==1.0.3
rpds-py==0.9.2
safetensors==0.3.1
scikit-learn==1.3.0
scipy==1.11.1
semantic-version==2.10.0
sentence-transformers==2.2.2
sentencepiece==0.1.99
six==1.16.0
sniffio==1.3.0
SQLAlchemy==2.0.19
starlette==0.27.0
sympy==1.12
tabulate==0.9.0
tenacity==8.2.2
threadpoolctl==3.2.0
tiktoken==0.4.0
tokenizers==0.13.3
toolz==0.12.0
torch==2.0.1
torchkeras==3.9.2
torchvision==0.15.2
tqdm==4.65.0
transformers==4.31.0
triton==2.0.0
typing-inspect==0.9.0
typing_extensions==4.7.1
tzdata==2023.3
uc-micro-py==1.0.2
unstructured==0.8.1
urllib3==2.0.4
uvicorn==0.23.1
uWSGI @ file:///croot/uwsgi_1688631110587/work
websocket==0.2.1
websocket-client==1.6.2
websockets==11.0.3
Werkzeug==2.3.6
xlrd==2.0.1
XlsxWriter==3.1.2
xxhash==3.2.0
yarl==1.9.2
zope.event==5.0
zope.interface==6.0
import sys
sys.path.append("../..")
import gradio as gr
import torch
from contract.extraction import ElementsExtractor
from llm.chatglm import ChatGLMLocLLM
from llm.ernie import ErnieLLM
from llm.baichuan import BaichuanLLM
from loader.load import load_file,load
from common import consts
from flask.cli import load_dotenv
load_dotenv()
from argparse import Namespace
from llm.loader import ModelLoader
cfg = Namespace()
#model
cfg.model_name_or_path = consts.MODEL_PATH_ChatGLM2_32K
cfg.lora_checkpoint = '/home/zfh/aird/model/ckpt/chatglm-6b-lora-spdsvb-INSv8-1e-03-30'
cfg.pre_seq_len = None
cfg.quantization_bit = None
loader = ModelLoader(cfg.model_name_or_path, cfg.quantization_bit)
# loader.load_lora(cfg.lora_checkpoint)
model,tokenizer = loader.models()
model = model.eval()
# max_length=1000
# Define the Gradio interface
def contract(chatbot, file, input, history, max_length):
print("contract",file,input,max_length)
chatbot.append((input, ""))
if file is not None:
content = load(file.name)
content="\n".join([d.page_content for d in content])
print(len(content))
if content is not None and len(content) > 0:
if len(history) == 0:
prompt = f"{content}请基于上述内容回答以下问题:\n{input}"
else:
prompt = input
else:
prompt = input
# print("prompt: ",prompt)
for response, history in model.stream_chat(tokenizer, prompt, history, max_length=max_length): # type: ignore
chatbot[-1] = (input, response)
yield chatbot, history
def reset(history):
# assigning to a component's .value from a callback has no effect in Gradio 3.x,
# so only the conversation history state is cleared here
return []
with gr.Blocks() as demo:
gr.HTML("""<h1 align="center">商业合同要素提取</h1>""")
history = gr.State([])
chatbot = gr.Chatbot()
# with gr.Column(scale=4):
# with gr.Row():
# input=gr.Textbox(label="输入文本", type="text", lines=5)
with gr.Row():
with gr.Column(scale=4):
file = gr.File(label="上传文件")
with gr.Column(scale=4):
input=gr.Textbox(label="输入文本", type="text", lines=5)
with gr.Column(scale=4):
with gr.Row():
max_length = gr.Slider(1000, 30000, value=30000, step=1000, label="单次提取使用的文本长度", interactive=True)
with gr.Row():
submit_btn=gr.Button("开始提取")
reset_btn=gr.Button("重置")
reset_btn.click(reset,inputs=[history],outputs=[history])
submit_btn.click(contract,inputs=[chatbot, file, input, history, max_length],outputs=[chatbot,history])
demo.queue().launch(share=True)
\ No newline at end of file
from flask import Flask, jsonify, request
import tempfile
import shutil
import sys
sys.path.append("../..")
from llm.chatglm import ChatGLMLocLLM
from llm.ernie import ErnieLLM,ModelType
from llm.baichuan import BaichuanLLM
from loader.load import load_file,load
from contract.extraction import ElementsExtractor
app = Flask(__name__)
temp_dir = tempfile.mkdtemp()
# llm = ChatGLMLocLLM(model_name="/home/zfh/models/chatglm-6b")
llm = ErnieLLM(model_name=ModelType.ERNIE_LITE,access_token="24.a2dab1f44fdee40ff5fe1923d8dbdfcb.2592000.1692415837.282335-32870719")
extractor=ElementsExtractor(llm=llm)
@app.route('/contract', methods=['POST'])
def contract():
# 获取文件对象
print(request.files)
file = request.files['file']
if file is None:
return jsonify({'message': 'Error: could not load file'})
filename =temp_dir + '/' + file.filename
file.save(filename)
# 解析其中的参数
elements = request.form.get('elements').split(",")
max_length = request.form.get('max_length')
max_length = int(max_length) if max_length is not None else 1000
print(file,elements,max_length)
# 调用模型
docs = load(filename)
if docs is None:
return jsonify({'message': 'Error: could not load file'})
print(len(docs))
content = []
content_len = 0
values={}
for e in elements:
values[e]=""
for d in docs:
if content_len+len(d.page_content)>max_length:
doc = "\n".join(content)
eles = extractor.extract(doc,elements)
for e in eles:
try:
k,v = e.split(":",maxsplit=1)
k = k.strip()
v = v.strip()
if v is not None and v != "" and v!="未知" and k in elements:
values[k]=v+","+values[k] if k in values else v
except Exception as exp:
print(exp)
print(e)
continue
print("\n".join([f"{k}:{v}" for k,v in values.items()]))
content=[d.page_content]
content_len=len(d.page_content)
else:
content.append(d.page_content)
content_len+=len(d.page_content)
# flush the remaining chunk so documents shorter than max_length are also processed
if len(content) > 0:
doc = "\n".join(content)
eles = extractor.extract(doc,elements)
for e in eles:
try:
k,v = e.split(":",maxsplit=1)
k = k.strip()
v = v.strip()
if v is not None and v != "" and v!="未知" and k in elements:
values[k] = v if values[k] == "" else v + "," + values[k]
except Exception as exp:
print(exp)
print(e)
continue
print("\n".join([f"{k}:{v}" for k,v in values.items()]))
return jsonify(values)
@app.route('/example', methods=['GET'])
def example():
return jsonify({'message': 'Hello World!'})
if __name__ == '__main__':
app.run(debug=True,port=8080,host='0.0.0.0')
\ No newline at end of file
import sys
sys.path.append("../..")
import gradio as gr
import torch
from contract.extraction import ElementsExtractor
from llm.chatglm import ChatGLMLocLLM
from llm.ernie import ErnieLLM
from llm.baichuan import BaichuanLLM
from loader.load import load_file,load
from common import consts
from flask.cli import load_dotenv
load_dotenv()
# Load the model
# llms = ["ChatGLM","ChatGLM2","Ernie","ChatGLM2-32K"]
# llm = ChatGLMLocLLM(model_name=consts.MODEL_PATH_ChatGLM2_32K)
from llm.loader import ModelLoader
from llm.wrapper import WrapperLLM
loader = ModelLoader(consts.MODEL_PATH_ChatGLM2)
model,tokenizer = loader.models()
llm=WrapperLLM(model=model,tokenizer=tokenizer)
# llm = ErnieLLM()
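# extract_values runs the extractor over one accumulated chunk of text and merges the
# returned "要素:值" lines into the running dict, skipping empty answers and "未知".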
def extract_values(values,content, elements, extractor):
doc = "\n".join(content)
eles = extractor.extract_foreach(doc, elements)
# eles = extractor.extract(doc, elements)
for e in eles:
try:
k, v = e.split(":", maxsplit=1)
k = k.strip()
v = v.strip()
if v is not None and v != "" and v != "未知" and k in elements:
values[k] = v + "," + values[k] if k in values else v
except Exception as exp:
print(exp)
print(e)
continue
return values
extractor=ElementsExtractor(llm=llm)
elements = ["合同号","买方","卖方","合同金额","合同签订日期","装运标记","甲方","乙方","甲方地址","乙方地址"]
# max_length=1000
# Define the Gradio interface
def contract(file,elements,max_length):
print(file,elements,max_length)
docs = load(file.name)
if docs is None:
return "Error: could not load file"
print(len(docs))
content = []
content_len = 0
values={k:"" for k in elements}
for d in docs:
if content_len+len(d.page_content)>max_length:
values = extract_values(values,content, elements, extractor)
print("\n".join([f"{k}:{v}" for k,v in values.items()]))
content=[d.page_content]
content_len=len(d.page_content)
else:
content.append(d.page_content)
content_len+=len(d.page_content)
values = extract_values(values,content, elements, extractor)
return "\n".join([f"{k}:{v}" for k,v in values.items()])
def change_llm_type(llm_type):
print("change_llm_type",llm_type)
global llm,extractor
del llm
llm=ErnieLLM()
torch.cuda.empty_cache()
if llm_type=="ChatGLM":
llm = ChatGLMLocLLM(model_name=consts.MODEL_PATH_ChatGLM)
elif llm_type=="ChatGLM2":
llm = ChatGLMLocLLM(model_name=consts.MODEL_PATH_ChatGLM2)
elif llm_type=="ChatGLM2-32k":
llm = ChatGLMLocLLM(model_name=consts.MODEL_PATH_ChatGLM2_32K)
elif llm_type=="Ernie":
llm = ErnieLLM()
elif llm_type=="baichuan-13b":
llm = BaichuanLLM(model_name="../../models/Baichuan-13B-Chat",quantization_bit=8)
else:
llm = ErnieLLM()
if llm is not None:
extractor=ElementsExtractor(llm=llm)
return llm_type
def add_element(ele_new):
print("add_element",elements,ele_new)
elements.append(ele_new)
return {ele_group:gr.update(choices=elements),
ele_new_box:gr.update(value="")}
def reset():
output.value=""
file.value=None
with gr.Blocks() as demo:
gr.HTML("""<h1 align="center">商业合同要素提取</h1>""")
with gr.Row():
with gr.Column(scale=4):
with gr.Row():
file = gr.File(label="上传文件")
with gr.Row():
submit_btn=gr.Button("开始提取")
# reset_btn=gr.Button("重置", type="reset")
# reset_btn.click(reset)
with gr.Row():
output=gr.Textbox(label="提取结果", type="text", lines=20)
with gr.Column(scale=1):
with gr.Row():
max_length = gr.Slider(1000, 30000, value=5000, step=1000, label="单次提取使用的文本长度", interactive=True)
# if llm.model_name==consts.MODEL_PATH_ChatGLM2_32K:
# max_length.value=30000
# with gr.Row():
# llm_type = gr.Radio(llms, label="语言模型类型", value="ChatGLM2", interactive=True)
# llm_type.change(change_llm_type, inputs=[llm_type],outputs=[llm_type])
with gr.Row():
ele_group = gr.CheckboxGroup(choices=elements, label="需要提取的元素", value=elements, interactive=True)
with gr.Row():
ele_new_box = gr.Textbox(label="新增元素", type="text", lines=1)
ele_new_btn = gr.Button("新增")
ele_new_btn.click(add_element,inputs=[ele_new_box],outputs=[ele_group,ele_new_box])
submit_btn.click(contract,inputs=[file,ele_group,max_length],outputs=output)
demo.queue().launch(share=True)
\ No newline at end of file
#!/bin/bash
# Set the path to the server.py script
SERVER_PATH=server.py
# Set the default values for the arguments
MODEL_NAME_OR_PATH="../../../model/chatglm2-6b"
CHECKPOINT=None
CHECKPOINT_PATH="../../../model/ckpt/chatglm2-6b-qlora-INSv11-rank16-1e-3-30/checkpoint-2000"
PRE_SEQ_LEN=128
QUANTIZATION_BIT=8
PORT=8002
# Call the server.py script with the parsed arguments
python $SERVER_PATH \
--model_name_or_path $MODEL_NAME_OR_PATH \
--checkpoint $CHECKPOINT \
--checkpoint_path $CHECKPOINT_PATH \
--pre_seq_len $PRE_SEQ_LEN \
--quantization_bit $QUANTIZATION_BIT \
--port $PORT
\ No newline at end of file
#!/bin/bash
# Set the path to the server.py script
SERVER_PATH=server.py
# Set the default values for the arguments
MODEL_NAME_OR_PATH="../../../model/chatglm2-6b"
CHECKPOINT=lora
CHECKPOINT_PATH="../../../model/ckpt/chatglm2-6b-qlora-INSv11-rank16-1e-3-30/checkpoint-2000"
QUANTIZATION_BIT=8
PORT=8001
# Call the server.py script with the parsed arguments
python $SERVER_PATH \
--model_name_or_path $MODEL_NAME_OR_PATH \
--checkpoint $CHECKPOINT \
--checkpoint_path $CHECKPOINT_PATH \
--quantization_bit $QUANTIZATION_BIT \
--port $PORT
\ No newline at end of file
import argparse
import time
from fastapi import FastAPI, Request
from fastapi.responses import StreamingResponse
from transformers import AutoModel, AutoTokenizer, BitsAndBytesConfig
from peft import PeftModel, PeftConfig
import uvicorn, json, datetime
import torch
import asyncio
import os
import sys
sys.path.append("../..")
from typing import AsyncIterable, Awaitable
from pydantic import BaseModel
# from dotenv import load_dotenv
import signal
DEVICE = "cuda"
DEVICE_ID = "0"
CUDA_DEVICE = f"{DEVICE}:{DEVICE_ID}" if DEVICE_ID else DEVICE
def torch_gc():
if torch.cuda.is_available():
with torch.cuda.device(CUDA_DEVICE):
torch.cuda.empty_cache()
torch.cuda.ipc_collect()
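# The HTTP API exchanges history as a list of {'q': ..., 'a': ...} dicts; build_history and
# convert_data translate between that shape and the [(query, response)] tuples that the
# ChatGLM chat/stream_chat APIs expect.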
def build_history(history):
result = []
for item in history if history else []:
result.append((item['q'], item['a']))
return result
def convert_data(data):
result = []
for item in data:
result.append({'q': item[0], 'a': item[1]})
return result
class StreamRequest(BaseModel):
"""Request body for streaming."""
message: str
stop_stream = False
def signal_handler(sig, frame):
global stop_stream
stop_stream = True
async def send_message(message: str,history=[],max_length=2048,top_p=0.7,temperature=0.95) -> AsyncIterable[str]:
global model, tokenizer, stop_stream
count = 0
old_len = 0
print(message)
output = ''
for response, history in model.stream_chat(tokenizer, message, history=history,
max_length=max_length,
top_p=top_p,
temperature=temperature):
# print(old_len,count)
if stop_stream:
stop_stream = False
break
else:
output = response[old_len:]
print(output, end='',flush=True)
# print(output)
old_len = len(response)
signal.signal(signal.SIGINT, signal_handler)
yield f"{output}"
print("")
# yield f"\n"
# print()
app = FastAPI()
@app.post("/stream")
async def stream(request: Request):
json_post_raw = await request.json()
json_post = json.dumps(json_post_raw)
json_post_list = json.loads(json_post)
prompt = json_post_list.get('prompt')
history = build_history(json_post_list.get('history'))
max_length = json_post_list.get('max_length')
top_p = json_post_list.get('top_p')
temperature = json_post_list.get('temperature')
return StreamingResponse(send_message(prompt,history=history,max_length=max_length if max_length else 2048,
top_p=top_p if top_p else 0.7,
temperature=temperature if temperature else 0.95), media_type="text/plain")
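# Example client call for the /stream endpoint (a minimal sketch; assumes the server is
# reachable at http://127.0.0.1:8000 and that `requests` is installed):
#
#   import requests
#   resp = requests.post(
#       "http://127.0.0.1:8000/stream",
#       json={"prompt": "你好", "history": [{"q": "之前的问题", "a": "之前的回答"}]},
#       stream=True,
#   )
#   for chunk in resp.iter_content(chunk_size=None, decode_unicode=True):
#       print(chunk, end="", flush=True)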
@app.post("/")
async def create_item(request: Request):
global model, tokenizer
json_post_raw = await request.json()
json_post = json.dumps(json_post_raw)
json_post_list = json.loads(json_post)
prompt = json_post_list.get('prompt')
history = build_history(json_post_list.get('history'))
max_length = json_post_list.get('max_length')
top_p = json_post_list.get('top_p')
temperature = json_post_list.get('temperature')
response, history = model.chat(tokenizer,
prompt,
history=history,
max_length=max_length if max_length else 2048,
top_p=top_p if top_p else 0.7,
temperature=temperature if temperature else 0.95)
now = datetime.datetime.now()
time = now.strftime("%Y-%m-%d %H:%M:%S")
answer = {
"response": response,
"history": history,
"status": 200,
"time": time
}
log = "[" + time + "] " + '", prompt:"' + prompt + '", response:"' + repr(response) + '"'
print(log)
torch_gc()
return answer
@app.post("/tokens")
async def get_num_tokens(request: Request):
global model, tokenizer
json_post_raw = await request.json()
json_post = json.dumps(json_post_raw)
json_post_list = json.loads(json_post)
prompt = json_post_list.get('prompt')
tokens = tokenizer.encode(prompt, add_special_tokens=False)
print("=======================================")
print("=======================================")
print(len(tokens),prompt)
print("=======================================")
print("=======================================")
now = datetime.datetime.now()
time = now.strftime("%Y-%m-%d %H:%M:%S")
answer = {
"response": len(tokens),
"status": 200,
"time": time
}
return answer
def parse_args():
parser = argparse.ArgumentParser(description='ChatGLM2-6B Server')
parser.add_argument('--model_name_or_path', type=str, default='THUDM/chatglm2-6b', help='模型id或local path')
parser.add_argument('--checkpoint', type=str, default=None, help='checkpoint类型(None、ptuning、lora)')
parser.add_argument('--checkpoint_path', type=str, default='../../../model/ckpt/chatglm2-6b-qlora-INSv11-rank16-1e-3-30/checkpoint-2000', help='checkpoint路径')
parser.add_argument('--pre_seq_len', type=int, default=128, help='prefix 长度')
parser.add_argument('--quantization_bit', type=int, default=None, help='是否量化')
parser.add_argument('--port', type=int, default=8000, help='端口')
parser.add_argument('--host', type=str, default='0.0.0.0', help='host')
# parser.add_argument('--max_input_length', type=int, default=512, help='instruction + input的最大长度')
# parser.add_argument('--max_output_length', type=int, default=1536, help='output的最大长度')
return parser.parse_args()
if __name__ == '__main__':
cfg=parse_args()
## ----------- load model --------------
from llm.loader import ModelLoader
start = time.time()
if cfg.checkpoint == "lora":
# lora 微调 checkpoint 及模型加载
loader = ModelLoader(cfg.model_name_or_path)
loader.load_lora(cfg.checkpoint_path)
elif cfg.checkpoint == "ptuning":
# ptuning v2 微调 checkpoint 及模型加载
loader = ModelLoader(cfg.model_name_or_path, cfg.pre_seq_len, False)
loader.load_prefix(cfg.checkpoint_path)
else:
loader = ModelLoader(cfg.model_name_or_path)
model,tokenizer = loader.models()
if cfg.quantization_bit is not None:
model = loader.quantize(cfg.quantization_bit)
model.cuda().eval()
uvicorn.run(app, host=cfg.host, port=cfg.port, workers=1)
\ No newline at end of file
from fastapi import FastAPI, HTTPException, Header
from pydantic import BaseModel
import sys
sys.path.append("../..")
from llm.chatglm import ChatGLMSerLLM
from langchain import LLMChain
from langchain.vectorstores.faiss import FAISS
app = FastAPI()
import math
from fastapi.middleware.cors import CORSMiddleware
from scenarios.psbc.model_serve import modelcall_prase, CHATGLM_PROMPT
# 添加 CORS 中间件
app.add_middleware(
CORSMiddleware,
allow_origins=["*"], # 允许所有域访问,也可以指定特定域名
allow_credentials=True,
allow_methods=["*"], # 允许所有HTTP方法
allow_headers=["*"], # 允许所有HTTP头
)
model_url = "http://192.168.10.93:8000"
def customize_exception(code: int, message: str, data: dict):
return {
"code":code,
"message":message,
"data":data
}
@app.get("/query")
async def query(query:str=None):
# modelcall_prase没有将url和CHATGLM_PROMPT变量封装进去,在这里需要进行声明。
base_llm=ChatGLMSerLLM(url=model_url)
chose_llm = LLMChain(llm=base_llm, prompt=CHATGLM_PROMPT)
try:
result, title, summary, tags, category = modelcall_prase(chose_llm, query)
subcategories = []  # modelcall_prase does not return subcategories; keep the field so the response shape stays unchanged
except Exception as e:
print("模型调用过程或json解析过程出现错误:", e)
return customize_exception(40012, str(e), {})
return customize_exception(200, "success", {
"result": result,
"title": title,
"summary": summary,
"tags": tags,
"category": category,
"subcategories": subcategories
})
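# Note: a FAISS vectorstore cannot be supplied as an HTTP query parameter; in practice the
# store would be loaded once at startup (e.g. FAISS.load_local with the same embedding model
# used to build it) and referenced here. The threshold below is mapped onto FAISS's
# L2-distance scale, assuming unit-normalized embeddings.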
@app.get("/tags")
async def tags(tags_native:list, faiss_vectorstore:FAISS, threshold:float=0.7):
score_threshold = (1-threshold) * math.sqrt(2)
tags_new = []
for tag_native in tags_native:
res = faiss_vectorstore.similarity_search_with_score(query=tag_native, score_threshold=score_threshold)
for document_object, _ in res:
tags_new.append(document_object.page_content)
return customize_exception(200, "success", {
"tags_native": tags_native,
"tags_new": tags_new
})
if __name__ == "__main__":
import uvicorn
uvicorn.run(app, host="0.0.0.0", port=8888)
# -*- coding: utf-8 -*-
import os, sys
import pandas as pd
sys.path.append("../..")
import gradio as gr
import argparse
from llm.chatglm import ChatGLMSerLLM
from langchain.chains import LLMChain
from langchain.prompts import PromptTemplate
from typing import Awaitable
import asyncio
from langchain.callbacks import AsyncIteratorCallbackHandler
import re
from llm.ernie_with_sdk import ChatERNIESerLLM
from qianfan import ChatCompletion
chatglm3_prompt_qa = """请根据以下资料内容生成问答对,一共输出{num_selector}个问题。(请你按照以下格式进行回答)
Q:问题
A:答案
Q:问题
A:答案
...
资料内容如下所示:
{context}"""
qianfan_prompt_qa = """请根据以下资料内容生成问答对,一共输出{num_selector}个问题。(请你按照以下格式进行回答)
Q:问题
A:答案
Q:问题
A:答案
...
资料内容如下所示:
'''
{context}
'''"""
CHATGLM3_PROMPT_QA = PromptTemplate(input_variables=["context","num_selector"],template=chatglm3_prompt_qa)
QIANFAN_PROMPT_QA = PromptTemplate(input_variables=["context","num_selector"],template=qianfan_prompt_qa)
async def async_chat_qa(input_text, model,num_selector):
# yield gr.DataFrame(pd.DataFrame(), col_count=(3, "fixed"), row_count=(3, "fixed")) ,""
global qianfanchain_qa
global chatglm3chain_qa
# Create an asynchronous callback handler
callback = AsyncIteratorCallbackHandler()
# Define an asynchronous function to wrap another asynchronous function and signal completion or exceptions using an event
async def wrap_done(fn: Awaitable, event: asyncio.Event):
try:
await fn # Wait for the provided asynchronous function to complete
except Exception as e:
# TODO: Handle exceptions - here, we simply print the exception information
import traceback
traceback.print_exc()
print(f"Caught exception: {e}")
finally:
event.set() # Set the event to indicate completion
# Create a task to perform message generation with ChatOpenAI and monitor the completion event of the callback handler
if model == "ernie":
task = asyncio.create_task(wrap_done(qianfanchain_qa.arun({"context":input_text,"num_selector":num_selector},callbacks=[callback]),callback.done))
else:
task = asyncio.create_task(wrap_done(chatglm3chain_qa.arun({"context":input_text,"num_selector":num_selector},callbacks=[callback]),callback.done))
print("*"*20)
# Iterate asynchronously to obtain tokens from the callback handler
text=""
async for token in callback.aiter():
text=text+token
yield f"{text}"
await task # Wait for the task to complete
def on_select(evt: gr.SelectData, df): # SelectData is a subclass of EventData
print(f"You selected {evt.value} at {evt.index} from {evt.target}")
# 删除选择行
if evt.index[1] == 2:
df.drop(df.index[evt.index[0]],axis=0,inplace=True)
return gr.DataFrame(df,interactive=True)
else:
return df
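# parse_qa expects the model output to alternate "Q:..." / "A:..." lines (as requested by the
# prompts above) and turns each pair into a row of a three-column DataFrame; the third
# "删除" column is what on_select checks in order to delete the clicked row.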
def parse_qa(output_text):
output_text = output_text.replace("\n\n", "\n")
output_text += "\n"
qa_pairs = re.findall(r"Q:(.*?)A:(.*?)\n", output_text, re.DOTALL)
length = len(qa_pairs)
# formatted_qa_pairs = [(q,a) for q, a in qa_pairs]
df = pd.DataFrame({
"Q": [q for q, _ in qa_pairs],
"A": [a for _, a in qa_pairs],
"删除": ["删除" for _, _ in qa_pairs]
})
return gr.DataFrame(df, row_count=(length, "fixed"))
with gr.Blocks() as demo:
gr.HTML("""<h1 align="center">辅助生成知识库</h1>""")
# with gr.Row():
# input_text = gr.Textbox(show_label=True, placeholder="输入需要处理的文档...", lines=10)
with gr.Row():
input_text = gr.Textbox(show_label=True, placeholder="输入需要处理的文档...", lines=10, scale=9)
model_selector = gr.Dropdown(choices=["ernie","chatglm3"], label="请选择一个模型", scale=1, min_width=50, value="chatglm3")
with gr.Row():
num_selector = gr.Slider(minimum=0, maximum=10, value=5, label="请选择问题数量",step=1)
with gr.Row():
qaBtn = gr.Button("QA问答对生成")
dataframe = gr.DataFrame(visible=True,interactive=True,column_widths=["30%", "60%", "10%"], col_count=(3, "fixed"), row_count=(1, "fixed"))
dataframe.select(on_select, inputs=[dataframe], outputs=[dataframe])
gr.Markdown("""---""")
output_text = gr.Textbox(show_label=True, placeholder="输出...", lines=10)
# clearBtn = gr.Button("清除")
# clearBtn.click(clear, [], [dataframe, output_text])
qaBtn.click(async_chat_qa, [input_text, model_selector,num_selector], [ output_text], queue=True).then(
parse_qa,
[output_text],
[dataframe],
queue=False
)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--port", type=int, default=7658)
parser.add_argument("--host", type=str, default="192.168.0.66")
parser.add_argument("--base_llm_url", type=str, default="http://192.168.22.106:8003")
args = parser.parse_args()
global base_llm_url, qianfanchain_qa, chatglm3chain_qa
base_llm_url=os.environ.get('LLM_URL_BASE', None)
if not base_llm_url:
base_llm_url=args.base_llm_url
base_llm1=ChatGLMSerLLM(url=base_llm_url)
base_llm2=ChatERNIESerLLM(chat_completion=ChatCompletion(ak="pT7sV1smp4AeDl0LjyZuHBV9", sk="b3N0ibo1IKTLZlSs7weZc8jdR0oHjyMu"))
qianfanchain_qa = LLMChain(llm=base_llm2, prompt=QIANFAN_PROMPT_QA,llm_kwargs={"temperature":0.9})
chatglm3chain_qa = LLMChain(llm=base_llm1, prompt=CHATGLM3_PROMPT_QA,llm_kwargs={"temperature":0.9})
demo.queue().launch(share=False, inbrowser=True,server_name=args.host,server_port=args.port)
# -*- coding: utf-8 -*-
import os, sys
sys.path.append("../..")
import gradio as gr
import argparse
from llm.chatglm import ChatGLMSerLLM
from langchain.chains import LLMChain
from langchain.prompts import PromptTemplate
from typing import Awaitable
import asyncio
from langchain.callbacks import AsyncIteratorCallbackHandler
import re
from llm.ernie_with_sdk import ChatERNIESerLLM
from qianfan import ChatCompletion
from difflib import Differ
chatglm3_prompt_typos = """请仔细阅读以下资料,更正资料中的错别字,并按照以下格式进行输出。
更正后的资料:
文本内容
文本内容:
{context}"""
qianfan_prompt_typos = """'''
{context}
'''
已知有上述文本,现需要你仔细阅读上述资料,更正资料中的错别字之后进行输出。你需要按照如下格式进行输出:
更正之后的资料:
"""
CHATGLM3_PROMPT_TYPOS = PromptTemplate(input_variables=["context"],template=chatglm3_prompt_typos)
QIANFAN_PROMPT_TYPOS = PromptTemplate(input_variables=["context"], template=qianfan_prompt_typos)
async def async_chat_typos(input_text, model):
# yield gr.DataFrame(col_count=(3, "fixed")) ,""
global base_llm1, base_llm2, step
qianfanchain_stc = LLMChain(llm=base_llm2, prompt=QIANFAN_PROMPT_TYPOS,llm_kwargs={"temperature":0.9})
chatglm3chain_stc = LLMChain(llm=base_llm1, prompt=CHATGLM3_PROMPT_TYPOS,llm_kwargs={"temperature":0.9})
# Create an asynchronous callback handler
callback = AsyncIteratorCallbackHandler()
# Define an asynchronous function to wrap another asynchronous function and signal completion or exceptions using an event
async def wrap_done(fn: Awaitable, event: asyncio.Event):
try:
await fn # Wait for the provided asynchronous function to complete
except Exception as e:
# TODO: Handle exceptions - here, we simply print the exception information
import traceback
traceback.print_exc()
print(f"Caught exception: {e}")
finally:
event.set() # Set the event to indicate completion
# Create a task to perform message generation with ChatOpenAI and monitor the completion event of the callback handler
if model == "ernie":
task = asyncio.create_task(wrap_done(qianfanchain_stc.arun({"context":input_text},callbacks=[callback]),callback.done))
else:
task = asyncio.create_task(wrap_done(chatglm3chain_stc.arun({"context":input_text},callbacks=[callback]),callback.done))
print("*"*20)
# Iterate asynchronously to obtain tokens from the callback handler
text=""
async for token in callback.aiter():
text=text+token
yield f"{text}"
await task # Wait for the task to complete
def parse_typos(output_text):
print(output_text)
output_text = output_text.replace("\n\n", "\n")
output_text = '\n'.join(output_text.splitlines()[1:])
return output_text
def diff_texts(text1, text2):
d = Differ()
return [
(token[2:], token[0] if token[0] != " " else None)
for token in d.compare(text1, text2)
]
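# diff_texts feeds the original and corrected text through difflib.Differ character by
# character; each emitted token carries a two-character prefix ("+ ", "- " or "  "), which is
# stripped and mapped to the category that gr.HighlightedText renders below.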
with gr.Blocks() as demo:
gr.HTML("""<h1 align="center">文章纠错</h1>""")
# with gr.Row():
with gr.Row():
input_text = gr.Textbox(show_label=True, placeholder="输入需要处理的文档...", lines=10, scale=9)
model_selector = gr.Dropdown(choices=["ernie","chatglm3"], label="请选择一个模型", scale=1, min_width=50, value="chatglm3")
submitBtn = gr.Button("提交")
output_text = gr.Textbox(show_label=True, placeholder="输出...", lines=10)
gr.Interface(
diff_texts,
[input_text, output_text],
gr.HighlightedText(
label="Diff",
combine_adjacent=True,
show_legend=True,
color_map={"+": "red", "-": "green"}),
theme=gr.themes.Base()
)
submitBtn.click(
async_chat_typos, [input_text, model_selector],[output_text], queue=True
).then(
parse_typos, [output_text], [output_text], queue=False
)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--port", type=int, default=7655)
parser.add_argument("--host", type=str, default="192.168.0.66")
parser.add_argument("--base_llm_url", type=str, default="http://192.168.22.106:8003")
args = parser.parse_args()
global base_llm_url,llmchain,step
base_llm_url=os.environ.get('LLM_URL_BASE', None)
if not base_llm_url:
base_llm_url=args.base_llm_url
base_llm1=ChatGLMSerLLM(url=base_llm_url)
base_llm2=ChatERNIESerLLM(chat_completion=ChatCompletion(ak="pT7sV1smp4AeDl0LjyZuHBV9", sk="b3N0ibo1IKTLZlSs7weZc8jdR0oHjyMu"))
# llmchain = LLMChain(llm=base_llm, prompt=CHATGLM3_PROMPT_TYPOS,llm_kwargs={"temperature":0.9})
# openai = OpenAI(model_name="chatglm3-6b", openai_api_key="token1",openai_api_base=base_llm_url + "/v1")
# llmchain = LLMChain(llm=openai, prompt=CHATGLM_PROMPT_CT,verbose=True,llm_kwargs={"temperature":1.0})
demo.queue().launch(share=False, inbrowser=True,server_name=args.host,server_port=args.port)
import sys
sys.path.append("../..")
from llm.chatglm import ChatGLMSerLLM
from langchain import LLMChain
import json
from langchain.prompts import StringPromptTemplate,PromptTemplate
import re
chatglm_prompt = """请仔细阅读以下文本内容,并根据文本的主题和关键信息,生成一个包含摘要、标签列表和分类信息的JSON结构。确保摘要准确捕捉文本的核心内容,标签列表精确反映文本的主旨,并且分类信息与文本内容紧密相关。请按照以下JSON结构格式输出结果:(注意,你的回复中只能有json)
```{{
"title": 给文本添加标题,标题简洁明了地概括文本的主题。,
"summary": 这里是文本的摘要,简洁明了地概括文本的主要内容。,
"tags": 为文本添加标签,结果为list。,
"category": 给文本添加分类,结果为字符串。
}}```
文本内容如下所示:
{context}
"""
CHATGLM_PROMPT = PromptTemplate(input_variables=["context"],template=chatglm_prompt)
chatglm_prompt_ct = """请仔细阅读以下文本内容,并根据文本的主题和关键信息,生成一个包含摘要、标签列表和分类信息的JSON结构。确保摘要准确捕捉文本的核心内容,标签列表精确反映文本的主旨,并且分类信息与文本内容紧密相关。请按照以下JSON结构格式输出结果:(注意,你的回复中只能有json)
```{{
"title": 给文本添加标题,标题简洁明了地概括文本的主题。,
"summary": 这里是文本的摘要,概括文本的主要内容。,
"tags": 提取文本中的关键词,结果为list。,
"category": 给文本添加分类,结果为字符串。
}}```
注意:category必须属于集合之一:[{category}]
------------------------
文本内容如下所示:
{context}"""
CHATGLM_PROMPT_CT = PromptTemplate(input_variables=["context", "category"],template=chatglm_prompt_ct)
query = """为什么说现代经济是信用经济?
答:(1)现代经济运作的特点。信用关系无处不在;信用规模呈现不断扩张趋势;信用结构日趋复杂化 。
(2)从信用关系的各部门分析:盈余与赤字、债权与债务。
(3)从信用关系中的主体来分析。由于经济中广泛存在着专门调剂资金余缺的金融机构,借贷双方不需要直接见面,通过金融机构作为中介人,便可解决资金的融通,从而进一步促进了信用和信用关系的发展。信用的发展,又大大促进了生产力和经济的发展。"""
def extract_first_code_block(input_string):
if input_string and input_string[0] == '{' and input_string[-1] == '}':
return input_string
pattern1 = r'```json(.*?)```'
match1 = re.search(pattern1, input_string, re.DOTALL)
if match1:
return match1.group(1)
else:
pattern2 = r'```(.*?)```'
match2 = re.search(pattern2, input_string, re.DOTALL)
if match2:
return match2.group(1)
else:
return None
def prase(answer):
json_data = extract_first_code_block(answer) # 从输出中拿json数据
# print("==================================================")
# print(json_data)
# print("==================================================")
data = {}
try:
data = json.loads(json_data)
except Exception as e:
print(f"在大模型给的结果转化为json形式的时候出现错误: {e}")
title = data.get("title", "")
summary = data.get("summary", "")
tags = data.get("tags", "")
category = data.get("category", "")
return title, summary, tags, category
def modelcall_prase(llmchain,input):
result = llmchain.invoke({"context":input})
title, summary, tags, category = prase(result["text"])
return result, title, summary, tags, category
def modelcall_prase_ct(llmchain, input, category):
result = llmchain.invoke({"context":input, "category":category})
title, summary, tags, category = prase(result["text"])
return result, title, summary, tags, category
chatglm_prompt_tags = """请认真阅读下段文本内容,并根据其内容生成一个或多个标签。你需要回复一个列表,该数组中的元素为与文本内容相符号的标签,除此之外无需回复其他内容。文本内容如下:
{context}"""
chatglm_prompt_title = """请认真阅读下段文本内容,并根据其内容信息生成其标题名称。你只需要给出标题,无需输出其他内容。文本内容如下:
{context}"""
chatglm_prompt_summary = """请认真阅读下段文本内容,并用一句话描述其中心内容。你只需要输出该句话,无需输出其他内容。文本内容如下:
{context}"""
chatglm_prompt_category = """请认真阅读下段文本内容,并在预定义好的类别中匹配与文本所示内容最相似的一个。你只需要输出该类别,无需输出其他内容。注意:必须是给出的类别中的一个,如果找不到与文本所述相匹配的类型则输出“不限”即可。文本内容如下:
{context}
预定义的类别如下:
{category}"""
CHATGLM_PROMPT_TAGS = PromptTemplate(input_variables=["context"],template=chatglm_prompt_tags)
CHATGLM_PROMPT_TITLE = PromptTemplate(input_variables=["context"],template=chatglm_prompt_title)
CHATGLM_PROMPT_SUMMARY = PromptTemplate(input_variables=["context"],template=chatglm_prompt_summary)
CHATGLM_PROMPT_CATEGORY = PromptTemplate(input_variables=["context", "category"],template=chatglm_prompt_category)
if __name__ == "__main__":
base_llm=ChatGLMSerLLM(url="http://192.168.10.93:8000")
# chose_llm = LLMChain(llm=base_llm, prompt=CHATGLM_PROMPT_TAGS)
# tags_get1 = chose_llm.run({"context":query})
# print(tags_get1)
# chose_llm = LLMChain(llm=base_llm, prompt=CHATGLM_PROMPT_CATEGORY)
# print(chose_llm.run({"context":query}))
# chose_llm = LLMChain(llm=base_llm, prompt=CHATGLM_PROMPT_TITLE)
# print(chose_llm.run({"context":query}))
# chose_llm = LLMChain(llm=base_llm, prompt=CHATGLM_PROMPT_SUMMARY)
# print(chose_llm.run({"context":query, "category":"热点问题、手机银行、转账汇款、账户管理、个人网银、信贷业务、自助设备"}))
chose_llm = LLMChain(llm=base_llm, prompt=CHATGLM_PROMPT_CT)
result, title, summary, tags, category = modelcall_prase_ct(chose_llm,input = query,category = "金融,财政,计算机,学生,教师,学校,社会")
print("result:",result)
print("title",title)
print("summary",summary)
print("tags",tags)
print("category",category)
print("subcategories",subcategories)
# """{
# "summary": "储蓄国债(电子式)和储蓄国债(凭证式)的特点和操作说明。",
# "tags": [
# "储蓄国债",
# "电子式",
# "凭证式",
# "付息",
# "兑付",
# "提前兑取",
# "非交易过户",
# "财产证明",
# "质押贷款"
# ],
# "category": "金融产品",
# "subcategories": [
# "债券",
# "国债",
# "投资"
# ]
# }"""
\ No newline at end of file
from langchain.prompts import PromptTemplate
chatglm3_prompt_tfq = """<|system|>
你是一个可以将一段文本根据其内容生成可以用于考试的判断题,并按照一定格式进行输出的工具。具体的输出格式如下所示:
试题1:
题目:。
答案:正确/错误。
解析:用到的资料中的内容。
试题2:
题目:。
答案:正确/错误。
解析:用到的资料中的内容。
...
<|user|>
请你根据下面这段文本生成{num_selector}个判断题。
文本内容如下:
{context}"""
chatglm3_prompt_mcq = """<|system|>
你是一个可以将一段文本根据其内容生成可以用于考试的选择题,并按照一定格式进行输出的工具,具体的输出格式如下所示:
试题1:
题目:。
A.选项内容 B.选项内容 C.选项内容 D.选项内容
正确答案:给出具体选项。
试题2:
题目:。
A.选项内容 B.选项内容 C.选项内容 D.选项内容
正确答案:给出具体选项。
...
<|user|>
请你根据下面这段文本生成{num_selector}个选择题。
文本内容如下:
{context}"""
qianfan_prompt_tfq = """'''
{context}
'''
请根据上面提供的知识资料,生成可以作为考试的判断题,并给出正确答案,一共输出{num_selector}个问题。按照如下格式进行回答:
试题1:
题目:试题内容。
正确答案:正确/错误
解析:选择原因(如果用到原资料中的内容,请列出来)
试题2:
题目:试题内容。
正确答案:正确/错误
解析:选择原因(如果用到原资料中的内容,请列出来)
...
"""
qianfan_prompt_mcq = """'''
{context}
'''
请你根据上面这一段文本的内容生成可以用于考试的选择题,并按照一定格式进行输出,一共输出{num_selector}个问题。具体的输出格式如下所示:
试题1:
题目:
A.选项内容
B.选项内容
C.选项内容
D.选项内容
正确答案:给出具体选项。
试题2:
题目:
A.选项内容
B.选项内容
C.选项内容
D.选项内容
正确答案:给出具体选项。
...
"""
chatglm3_prompt_qa = """请根据以下资料内容生成{num_selector}个问答对。(请你按照以下格式进行回答)
Q:问题
A:答案
Q:问题
A:答案
...
资料内容如下所示:
{context}"""
qianfan_prompt_qa = """'''
{context}
'''
请根据以上资料内容生成{num_selector}个问答对。(请你按照以下格式进行回答)
Q:问题
A:答案
Q:问题
A:答案
...
"""
# chatglm3_prompt_struct_s1 = """<|system|>
# 你是一个可以将一段文本根据其内容将其划分为{p_number}段并按照一定格式进行输出的工具。你需要按照如下的格式进行输出:
# 段落1:
# 段落内容
# 段落2:
# 段落内容
# ...
# (注意:必须保证所有的段落内容之和为文本原文)
# <|user|>
# 文本内容如下:
# {context}"""
chatglm3_prompt_struct_s2 = """{context}
请为上述文本取一个标题(除标题之外不可输出其他内容)。"""
# qianfan_prompt_struct_s1 = """文本内容如下:
# '''
# {context}
# '''
# 你需要将上述文本根据其内容将其划分为{p_number}段并按照如下的格式进行输出:
# 段落1:
# 段落内容
# 段落2:
# 段落内容
# ...
# (注意:原始文本内容必须全部被分割)"""
qianfan_prompt_struct_s2 = """'''
{context}
'''
请为上述文本取一个标题(输出绝对不可出现除标题之外的字):"""
chatglm3_prompt_typos1 = """{context}
已知有上述文本,现需要你仔细阅读上述资料,更正资料中的错别字并将文本语气更改为{tone}语气之后进行输出。你需要按照如下的格式进行输出:
更正之后的资料:
"""
qianfan_prompt_typos1 = """'''
{context}
'''
已知有上述文本,现需要你仔细阅读上述资料,更正资料中的错别字并将文本语气更改为{tone}语气之后进行输出。你需要按照如下的格式进行输出:
更正之后的资料:
"""
chatglm3_prompt_typos2 = """{context}
已知有上述文本,现需要你仔细阅读上述资料,更正资料中的错别字之后进行输出。你需要按照如下的格式进行输出:
更正之后的资料:
"""
qianfan_prompt_typos2 = """'''
{context}
'''
已知有上述文本,现需要你仔细阅读上述资料,更正资料中的错别字之后进行输出。你需要按照如下的格式进行输出:
更正之后的资料:
"""
qianfan_prompt_stc1 = """'''
{context}
'''
请为上述文本取一个简短的标题,不要出现任何标点符号。"""
qianfan_prompt_stc2 = """'''
{context}
'''
请为上述文本取一个简短的标题,不要出现任何标点符号。"""
chatglm3_prompt_stc1 = """{context}
请为上述文本取一个标题(除标题之外不可输出其他内容)。"""
chatglm3_prompt_stc2 = """{context}
请为上述文本取一个标题(除标题之外不可输出其他内容)。"""
QIANFAN_PROMPT_STC1 = PromptTemplate(input_variables=["context"], template=qianfan_prompt_stc1)
QIANFAN_PROMPT_STC2 = PromptTemplate(input_variables=["context"], template=qianfan_prompt_stc2)
CHATGLM3_PROMPT_STC1 = PromptTemplate(input_variables=["context"], template=chatglm3_prompt_stc1)
CHATGLM3_PROMPT_STC2 = PromptTemplate(input_variables=["context"], template=chatglm3_prompt_stc2)
CHATGLM3_PROMPT_TFQ = PromptTemplate(input_variables=["context", "num_selector"], template=chatglm3_prompt_tfq)
CHATGLM3_PROMPT_MCQ = PromptTemplate(input_variables=["context", "num_selector"], template=chatglm3_prompt_mcq)
QIANFAN_PROMPT_TFQ = PromptTemplate(input_variables=["context", "num_selector"], template=qianfan_prompt_tfq)
QIANFAN_PROMPT_MCQ = PromptTemplate(input_variables=["context", "num_selector"], template=qianfan_prompt_mcq)
CHATGLM3_PROMPT_QA = PromptTemplate(input_variables=["context", "num_selector"], template=chatglm3_prompt_qa)
QIANFAN_PROMPT_QA = PromptTemplate(input_variables=["context", "num_selector"], template=qianfan_prompt_qa)
# CHATGLM3_PROMPT_STRUCT_S1 = PromptTemplate(input_variables=["context"], template=chatglm3_prompt_struct_s1)
CHATGLM3_PROMPT_STRUCT_S2 = PromptTemplate(input_variables=["context"], template=chatglm3_prompt_struct_s2)
# QIANFAN_PROMPT_STRUCT_S1 = PromptTemplate(input_variables=["context"], template=qianfan_prompt_struct_s1)
QIANFAN_PROMPT_STRUCT_S2 = PromptTemplate(input_variables=["context"], template=qianfan_prompt_struct_s2)
CHATGLM3_PROMPT_TYPOS1 = PromptTemplate(input_variables=["context", "tone"], template=chatglm3_prompt_typos1)
QIANFAN_PROMPT_TYPOS1 = PromptTemplate(input_variables=["context", "tone"], template=qianfan_prompt_typos1)
CHATGLM3_PROMPT_TYPOS2 = PromptTemplate(input_variables=["context"], template=chatglm3_prompt_typos2)
QIANFAN_PROMPT_TYPOS2 = PromptTemplate(input_variables=["context"], template=qianfan_prompt_typos2)
\ No newline at end of file
from langchain.vectorstores.faiss import FAISS
from langchain.embeddings.base import Embeddings
from langchain.embeddings.huggingface import (
HuggingFaceEmbeddings,
)
import math
from langchain_core.documents import Document
def build_and_save_vectorfaiss(documents,folder_path: str, index_name: str = "psbc_tags",embedding: str = "C:\\Users\\15663\\AI\\models\\bge-large-zh-v1.5"):
embedding = HuggingFaceEmbeddings(model_name=embedding)
faiss_vectorstore = FAISS.from_documents(documents=documents,embedding=embedding)
faiss_vectorstore.save_local(folder_path=folder_path,index_name=index_name)
def load_vectorfaiss(folder_path: str, index_name: str = "index",embedding: str = "C:\\Users\\15663\\AI\\models\\bge-large-zh-v1.5"):
embedding = HuggingFaceEmbeddings(model_name=embedding)
return FAISS.load_local(folder_path=folder_path,embeddings=embedding,index_name=index_name)
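# search_tags looks up each candidate tag in the FAISS vocabulary index (k=1). The similarity
# threshold is converted to an L2-distance cutoff via (1 - threshold) * sqrt(2), a heuristic
# mapping that assumes unit-normalized embeddings; matches within the cutoff are returned as
# `tags`, and inputs with no close match are returned as `advice_tags`.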
def search_tags(native_tags: list, faiss_vectorstore: FAISS, threshold: float=0.7):
score_threshold = (1-threshold) * math.sqrt(2)
tags = []
advice_tags = []
for native_tag in native_tags:
res = faiss_vectorstore.similarity_search_with_score(query=native_tag, score_threshold=score_threshold, k=1)
if len(res) == 0:
advice_tags.append(native_tag)
else:
tags.append(res[0][0].page_content)
return tags, advice_tags
if __name__ == "__main__":
tags = ["个人存款账户实名制","金融机构","身份证件","实名证件","规定施行","国家外汇管理局","行政许可","实施办法","外汇管理","政务服务","电子化办理","听证","监督检查"]
docs = [Document(page_content=tag) for tag in tags]
build_and_save_vectorfaiss(documents=docs,folder_path="./vectorstore",index_name="psbc_tags")
faiss_vectorstore = load_vectorfaiss(folder_path="./vectorstore",index_name="psbc_tags")
threshold = 0.7
result = search_tags(["金融机构"], faiss_vectorstore, threshold=threshold)
print(result)
\ No newline at end of file
import os, sys
from os import path
sys.path.append("../")
from abc import ABC, abstractmethod
import json
from typing import List,Any,Tuple,Dict
from langchain.schema import Document
from vector.pgsqldocstore import PgSqlDocstore,str2hash_base64
class DocumentCallback(ABC):
@abstractmethod #向量库储存前文档处理--
def before_store(self,docstore:PgSqlDocstore,documents:List[Document]) -> List[Document]:
pass
@abstractmethod #向量库查询后文档处理--用于结构建立
def after_search(self,docstore:PgSqlDocstore,documents:List[Tuple[Document, float]],number:int = 1000) -> List[Tuple[Document, float]]: #向量库查询后文档处理
pass
class DefaultDocumentCallback(DocumentCallback):
def before_store(self,docstore:PgSqlDocstore,documents:List[Document]) -> List[Document]:
output_doc:List[Document] = []
for doc in documents:
if "next_doc" in doc.metadata:
doc.metadata["next_hash"] = str2hash_base64(doc.metadata["next_doc"])
doc.metadata.pop("next_doc")
output_doc.append(doc)
return output_doc
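# after_search deduplicates retrieved documents by a hash of their content, then walks the
# "next_hash" chain recorded by before_store, pulling each following paragraph out of the
# Postgres docstore so results are returned together with their continuation text, stopping
# once the `number` limit is reached.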
def after_search(self,docstore:PgSqlDocstore,documents:List[Tuple[Document, float]],number:int = 1000) -> List[Tuple[Document, float]]: #向量库查询后文档处理
output_doc:List[Tuple[Document, float]] = []
exist_hash = []
for doc,score in documents:
print(exist_hash)
dochash = str2hash_base64(doc.page_content)
if dochash in exist_hash:
continue
else:
exist_hash.append(dochash)
output_doc.append((doc,score))
if len(output_doc) > number:
return output_doc
fordoc = doc
while ("next_hash" in fordoc.metadata):
if len(fordoc.metadata["next_hash"])>0:
if fordoc.metadata["next_hash"] in exist_hash:
break
else:
exist_hash.append(fordoc.metadata["next_hash"])
content = docstore.TXT_DOC.search(fordoc.metadata["next_hash"])
if content:
fordoc = Document(page_content=content[0], metadata=json.loads(content[1]))
output_doc.append((fordoc,score))
if len(output_doc) > number:
return output_doc
else:
break
else:
break
return output_doc
\ No newline at end of file
import re
import time
from pydantic import BaseModel
from langchain.prompts import StringPromptTemplate,PromptTemplate
from langchain import LLMChain
from qa.question import QuestionRDF
from similarity import VectorStore_FAISS
prompt_expert_template = """你是浦发硅谷银行网银系统的专家,请帮助解答用户在使用过程中遇到的问题。
{question}
"""
prompt_history_template = """{history}
上面是之前的对话,你可以继续回答用户的问题。
{question}
"""
prompt_enhancement_template = """{similarity}
请结合上述内容回答以下问题,不要提无关内容:
{question}
"""
prompt_enhancement_history_template = """{history}
上面是之前的对话,下面是可参考的内容。
{similarity}
请结合上述内容回答以下问题,不要提无关内容:
{question}
"""
class Chatbot:
def __init__(self, model, vectorstore_faiss:VectorStore_FAISS,tokenizer=None, base_model=None, base_tokenizer=None, source_prefix=None,re_history=None):
self.model = model
self.tokenizer = tokenizer
self.base_model = base_model
self.base_tokenizer = base_tokenizer
self.source_prefix = source_prefix
self.re_history = re_history
self.vectorstore_faiss = vectorstore_faiss
def _build_history(self, history=None):
if history is None:
return None
prompt=""
for i, (old_query, response) in enumerate(history):
prompt += "问:{}\n答:{}\n\n".format(old_query, response)
return prompt
def _build_prompt(self, query, history=None):
if history is None:
history = []
prompt = ""
for i, (old_query, response) in enumerate(history):
prompt += "[Round {}]\n\n问:{}\n\n答:{}\n\n".format(i + 1, old_query, response)
prompt += "[Round {}]\n\n问:{}\n\n答:".format(len(history) + 1, query)
return prompt
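# _build_history renders prior turns as plain "问:/答:" text for the prompt templates above,
# while _build_prompt reproduces the "[Round N]" 问/答 layout used by ChatGLM-style chat
# models, ending with the new question and an empty answer slot.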
def rdf_question(self, question, history):
rdf = QuestionRDF(self.base_model)
return rdf.generate(history, question)
def chat_with_llm(self, input, isExpert=False, isEnhancement=False, isBase=False, dialog=[], temp=0.8):
run_llm = self.model if isExpert else self.base_model
history = self._build_history(self.re_history(dialog)) if self.re_history is not None else self._build_history(dialog)
similarity=None
# if history is not None and len(history) > 0:
# question = self.rdf_question(input, history)
# else:
# question = input
# print("问:", question)
question = input
if isExpert: # 专家指令,不带历史
history, similarity = None, None
prompt=PromptTemplate.from_template(prompt_expert_template)
elif isEnhancement: # 知识增强,不带历史,补充相似度
history, similarity = None, self.vectorstore_faiss._join_document(self.vectorstore_faiss.get_text_similarity_with_score(input))
if similarity and similarity.strip(): # 有补充知识
print("相似度:", similarity)
prompt=PromptTemplate.from_template(prompt_enhancement_template) if history is None else PromptTemplate.from_template(prompt_enhancement_history_template)
else: # 无补充知识,退化为专家指令
run_llm = self.model
prompt=PromptTemplate.from_template(prompt_expert_template)
# prompt=PromptTemplate.from_template(prompt_history_template) if history is not None else PromptTemplate.from_template("{question}")
else: # 普通问答,带历史
question=input
prompt=PromptTemplate.from_template(prompt_history_template) if history is not None else PromptTemplate.from_template("{question}")
chain=LLMChain(llm=run_llm, prompt=prompt,llm_kwargs={"temperature":temp})
start=time.time()
response=chain.run({"history":history,"question":question,"similarity":similarity})
cost_time=time.time()-start
print("cost:", round(cost_time, 2), "s")
return response, [], input
def chat(self, input, isExpert=False, isEnhancement=False, isBase=None, dialog=[], temp=0.8):
if isExpert:
print("专家指令",end=" ")
if self.source_prefix is not None:
history = []
prompt = self.source_prefix + input
else:
prompt = input
input = f"[专家指令]{input}"
elif isEnhancement:
print("知识增强",end=" ")
history = []
similarity = self.vectorstore_faiss._join_document(self.vectorstore_faiss.get_text_similarity_with_score(input))
if similarity is not None :
prompt=f"{similarity}\n请根据上述内容回答以下问题:\n{input}"
elif self.source_prefix is not None:
prompt = self.source_prefix + input
else:
prompt = input
input = f"[知识增强]{input}"
else:
print("普通问答",end=" ")
if self.re_history is not None:
history = self.re_history(dialog)
else:
history = dialog
prompt = input
input = f"[普通问答]{input}"
if isBase is not None and isBase:
print("基础模型")
exec_model, exec_tokenizer = self.base_model, self.base_tokenizer
else:
print("增强模型")
exec_model, exec_tokenizer = self.model, self.tokenizer
start=time.time()
# response, history = exec_model.chat(
# exec_tokenizer, prompt, history, temperature=temp)
prompt = self._build_prompt(prompt, history)
response = exec_model(prompt)
cost_time=time.time()-start
if exec_tokenizer is not None:
input_token_size = len(exec_tokenizer.encode(input))
output_token_size = len(exec_tokenizer.encode(response))
print()
print("【Itoken_size:", input_token_size, "】prompt:", prompt)
print("【Otoken_size:", output_token_size, "】response:", response)
print("cost:", round(cost_time, 2), "s",
"tps:", round((input_token_size + output_token_size / cost_time), 2),"token/s")
print("--------------------------------------------------")
else:
print("cost:", round(cost_time, 2), "s")
print("--------------------------------------------------")
return response, history, input
\ No newline at end of file
import os
import sys
import time
import requests
sys.path.append("../..")
from llm.loader import ModelLoader
from common import consts
##👇--------- config -------------
from argparse import Namespace
cfg = Namespace()
cfg.vector = True
cfg.expert = True
cfg.max_source_length = 64
cfg.max_target_length = 128
cfg.pre_seq_len = 128
#model
cfg.model_name_or_path = consts.MODEL_PATH_ChatGLM2_32K #远程'THUDM/chatglm-6b'
cfg.quantization_bit = None #仅仅预测时可以选 4 or 8
cfg.source_prefix = consts.INSTRUCTION_V1
cfg.checkout_mode = "32k" # lora or ptuning
cfg.ckpt_path = '../../../model/ckpt/chatglm2-6b-32k-qlora-INSv11-rank16-5e-4-30'
cfg.ptuning_path = '../../../model/ckpt/chatglm2-6b-32k-pt-spdsvb-INSv11-128-3e-3-3000/checkpoint-3000'
cfg.output_path='../../../exam/chatglm2-6b-32k-vector'
##👆--------- config end -------------
## --------- load model --------------
if cfg.checkout_mode == "lora":
# lora 微调 checkpoint 及模型加载
loader = ModelLoader(cfg.model_name_or_path)
loader.load_lora(cfg.ckpt_path)
elif cfg.checkout_mode == "ptuning":
# ptuning v2 微调 checkpoint 及模型加载
loader = ModelLoader(cfg.model_name_or_path, cfg.pre_seq_len, False)
loader.load_prefix(cfg.ptuning_path)
else:
loader = ModelLoader(cfg.model_name_or_path)
model,tokenizer = loader.models()
if cfg.quantization_bit is not None:
model = loader.quantize(cfg.quantization_bit)
model = model.cuda().eval()
## --------- load model end --------------
## --------- questions --------------
questions = [
"证书更新时,提示“当前证书和用户绑定非同一证书“,这么处理",
"提示:该操作需要一个智能卡,但设备中目前没有智能卡,怎么处理",
"提示:证书库中没有可用的证书,原因是啥",
"提示:多于一把USBKEY,怎么处理",
"提示:出现了内部错误,怎么处理",
"提示:无法在所请求的端口上访问web站点,怎么处理",
"提示:谷歌浏览器无法反显操作员号,网银助手显示无异常,怎么处理",
"网银助手提示:没有检测到usbkey中的证书",
"网银支持的系统及浏览器",
"网银管理员忘记密码后怎么办",
"客户想给某一操作员开通手机银行,该如何操作?",
"网银密码设置规则",
"网银初始密码如何获取",
"动账联系人如何维护",
"登录密码失败最大次数",
"在柜面给网银客户挂新账号后,在网银做交易下拉框中没有刚挂的账号,是什么原因?",
"网银做交易时提示“授权模型不匹配”,是什么原因?",
"客户登录时密码输入框提示需下载安全控件,是什么原因?",
"客户经办了一笔人民币跨行转账后,发现填写错误,想要修改或撤销该笔交易,在哪个功能下可以执行该操作?",
"客户经办了一笔人民币跨行转账后,复核人员在哪个功能下可以进行相应的复核操作?",
"客户想要对某一账号设置日累计限额和笔数,在哪个功能下可以进行相应设置",
]
import pandas as pd
data = pd.DataFrame({"id":[i+1 for i in range(len(questions))],"question":questions})
start = time.time()
# 专家指令
if cfg.expert:
responses = []
for q in questions:
print(f"Q: {q}")
prompt = cfg.source_prefix + q
print("prompt:",prompt)
response,_=model.chat(tokenizer,prompt,temperature=0)
responses.append(response)
# data.loc[data.question==q,"response"] = response
print(f"A: {response}")
print("----"*10)
data["专家指令回答"] = responses
ins_time = time.time()
print(f"1 cost time:{ins_time-start}")
data["评分1"] = [""] * len(questions)
# 知识增强
if cfg.vector:
responses = []
for q in questions:
print(f"Q: {q}")
from similarity import get_text_similarity
similarity = get_text_similarity(q)
if similarity is not None:
prompt=f"{similarity}\n请结合上述内容回答以下问题:\n{q}"
else:
prompt = cfg.source_prefix + q
print("prompt:",prompt)
response,_=model.chat(tokenizer,prompt,temperature=0)
responses.append(response)
# data.loc[data.question==q,"response"] = response
print(f"A: {response}")
print("----"*10)
data["知识增强回答"] = responses
print(f"2 cost time:{time.time()-ins_time}")
data["评分2"] = [""] * len(questions)
data.to_csv(f"{cfg.output_path}.csv",index=False)
import sys
sys.path.append("../..")
from loader.load import load
def extract_values(values,content, elements, extractor):
doc = "\n".join(content)
eles = extractor.extract_foreach(doc, elements)
# eles = extractor.extract(doc, elements)
for e in eles:
try:
k, v = e.split(":", maxsplit=1)
k = k.strip()
v = v.strip()
if v is not None and v != "" and v != "未知" and k in elements:
values[k] = v + "," + values[k] if k in values else v
except Exception as exp:
print(exp)
print(e)
continue
return values
def contract(extractor,file,elements,max_length):
print(file,elements,max_length)
docs = load(file)
if docs is None:
return "Error: could not load file"
print(len(docs))
content = []
content_len = 0
values={k:"" for k in elements}
for d in docs:
if content_len+len(d.page_content)>max_length:
values = extract_values(values,content, elements, extractor)
print("\n".join([f"{k}:{v}" for k,v in values.items()]))
content=[d.page_content]
content_len=len(d.page_content)
else:
content.append(d.page_content)
content_len+=len(d.page_content)
values = extract_values(values,content, elements, extractor)
# return [f"{k}:{v}" for k,v in values.items()]
return values
\ No newline at end of file
import os, sys
import re
import time
sys.path.append("../..")
from contract.documentqa import DocumentQA,GetEmbding,GetRetriever
from llm.chatglm import ChatGLMSerLLM
from llm.ernie_with_sdk import ChatERNIESerLLM
from loader.load import load,loads_path,loads,append
from langchain.embeddings.huggingface import (
HuggingFaceEmbeddings,
)
from langchain.chains import RetrievalQA
from langchain.vectorstores import FAISS
from contract.prompts import (
QA_PROMPT,
REFINE_QA_PROMPT,
SUMMARISE_PROMPT,
REFINE_SUMMARISE_PROMPT,
EXTRACTION_PROMPT,
REFINE_EXTRACTION_PROMPT,
ROUTER_PROMPT,
# CHAT_QUESTION_PROMPT,
# CHAT_COMBINE_PROMPT
)
from langchain.callbacks.base import (
BaseCallbackHandler,
BaseCallbackManager
)
class PrintRetrievalHandler(BaseCallbackHandler):
def on_retriever_start(self, query: str, **kwargs):
print(f"**Question:** {query}")
def on_retriever_end(self, documents, **kwargs):
for idx, doc in enumerate(documents):
source = doc.metadata["source"]
print(f"**Results from {source}**")
print(doc.page_content)
base_llm=ChatGLMSerLLM(url="http://192.168.22.106:8002")
# base_llm=ChatGLMSerLLM(url="http://192.168.0.148:8000")
documentqa = DocumentQA(llm=base_llm)
def test_qa():
filepath = "/dataset/浦发硅谷电子渠道、企业网银、APIbanking、电子回单系统运维专业服务合同(4).docx"
# embedding_path = '/model/text2vec/text2vec-base-chinese'
embedding_path = '/model/moka-ai/m3e-large'
sentence_size = 1024
# retriever=GetRetriever(embdingpath=embedding_path,filepaths=[filepath],
# load_kwargs={"sentence_size":512},
# save_local=False,
# search_kwargs={"k":5})
# search_type="similarity_score_threshold",search_kwargs={"score_threshold":200.0})
# result = documentqa.summarize_document(filepaths=[filepath],load_kwargs={"sentence_size":sentence_size},chain_type="refine", chain_type_kwargs={"verbose":True},temperature=0.8)
start=time.time()
documents = loads(filepaths=[filepath],sentence_size=sentence_size*4/3)
# print("document_len",[len(doc.page_content) for doc in documents])
# print("documents size",len(documents))
# documents = append(documents=documents,sentence_size=sentence_size)
# print("document_len",[len(doc.page_content) for doc in documents])
# print("documents size",len(documents))
# documents = retriever._get_relevant_documents(query="受托人",run_manager=None)
# print([doc.page_content for doc in documents])
result = documentqa.summarize_document(documents=documents,chain_type="map_reduce", chain_type_kwargs={"token_max":sentence_size,"verbose":True})
# result = documentqa.qa_from_document(query="合同中的“受托人”名称",retriever=retriever,chain_type="map_reduce", chain_type_kwargs={"token_max":3000,"verbose":True},callbacks=[PrintRetrievalHandler()])
# result = documentqa.qa_from_document(query="合同服务期限",retriever=retriever,chain_type="refine", chain_type_kwargs={"verbose":True},callbacks=[PrintRetrievalHandler()])
# qa = RetrievalQA.from_chain_type(llm=base_llm, chain_type="map_reduce",
# chain_type_kwargs={"question_prompt":CHAT_QUESTION_PROMPT,"combine_prompt":CHAT_COMBINE_PROMPT,"token_max":3000,"verbose":True},retriever=retriever)
# result = qa.run("合同中的委托人公司名称和受托人公司名称分别是")
cost_time=time.time()-start
# print("document_len",[len(doc.page_content) for doc in documents])
# print("documents size",len(documents))
print("cost_time",cost_time)
print(result)
from vector.pgsql.db import PostgresDB
from similarity import VectorStore_FAISS
from langchain.schema import Document
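# test_faiss migrates existing question vectors from the Postgres store into a fresh local
# FAISS index: it pages through deduplicated (question, paragraph) rows with LIMIT/OFFSET,
# wraps each as a Document whose metadata keeps the source paragraph, adds them without
# re-splitting, and finally saves the index to disk.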
def test_faiss():
vecstore_faiss = VectorStore_FAISS(
embedding_model_name='../../../model/moka-ai/m3e-large',
store_path=os.path.join(os.path.expanduser('~'),'.beai/vectorstore_enhance'),
index_name="know",
info={"port":"5432","host":"192.168.22.106","dbname":"new_vecdoc","username":"vecdoc","password":"vecdoc"},
show_number=3,
reset=True)
psqldb = PostgresDB("192.168.22.106", "vecdoc", "vecdoc", "vecdoc")
psqldb.connect()
#将当前向量库中的数据全部导出到新的向量库中
db_index = 0
page_size = 2000
print(vecstore_faiss._faiss.index.ntotal)
while True:
# query = f"SELECT text,paragraph_id FROM vec_txt order by vector_id limit %s offset %s" % (page_size,db_index)
query = f"select vec_txt.text,vec_txt.paragraph_id,txt_doc.text,count(*) as count from vec_txt left join txt_doc on vec_txt.paragraph_id=txt_doc.paragraph_id where vec_txt.text not like E'%%\n%%' and vec_txt.text != '' and vec_txt.text != ' ' and vec_txt.paragraph_id in(select DISTINCT on (text) paragraph_id from txt_doc order by text,paragraph_id) group by vec_txt.text,txt_doc.text,vec_txt.paragraph_id limit %s offset %s" % (page_size,db_index)
psqldb.execute(query)
questions = psqldb.fetchall()
if len(questions) <= 0:
break
db_index+=page_size
list_of_documents = []
for question in questions:
list_of_documents.append(Document(page_content=question[0], metadata=dict(paragraph=question[2],page=question[1])))
vecstore_faiss._add_documents(list_of_documents,need_split=False)
print(vecstore_faiss._faiss.index.ntotal)
vecstore_faiss._save_local()
from langchain.document_loaders import UnstructuredFileLoader, TextLoader, CSVLoader,UnstructuredPDFLoader,UnstructuredWordDocumentLoader
def test_faiss_from_dir():
vecstore_faiss = VectorStore_FAISS(
embedding_model_name='../../../model/moka-ai/m3e-large',
store_path=os.path.join(os.path.expanduser('~'),'.beai/vectorstore_enhance'),
index_name="know",
info={"port":5434,"host":"192.168.22.106","dbname":"vecdoc","username":"vecdoc","password":"vecdoc"},
show_number=3,
reset=True)
docs = loads_path("../../../data/docs",mode="elements",sentence_size=1024)
print(len(docs))
docs = vecstore_faiss._tuple_deduplication(docs)
print(len(docs))
print(vecstore_faiss._faiss.index.ntotal)
for i in range(0, len(docs), 300):
vecstore_faiss._add_documents(docs[i:i+300 if i+300<len(docs) else len(docs)],need_split=True)
print(vecstore_faiss._faiss.index.ntotal)
vecstore_faiss._save_local()
def test_faiss_load():
vecstore_faiss = VectorStore_FAISS(
embedding_model_name='../../../model/moka-ai/m3e-large',
store_path=os.path.join(os.path.expanduser('~'),'.beai/vectorstore_enhance'),
index_name="know",
info={"port":5432,"host":"192.168.22.106","dbname":"new_vecdoc","username":"vecdoc","password":"vecdoc"},
show_number=3,
reset=False)
print(vecstore_faiss._join_document(vecstore_faiss.get_text_similarity_with_score("通知存款支取的业务规则")))
if __name__ == "__main__":
test_faiss_from_dir()
import logging
import os, sys
sys.path.append("../..")
logging.basicConfig(filename='web.log', level=logging.INFO)
logger = logging.getLogger(__name__)
from common import consts
import gradio as gr
import mdtex2html
import torch
import transformers
from transformers import (
AutoConfig,
AutoModel,
AutoTokenizer,
DataCollatorForSeq2Seq,
HfArgumentParser,
Seq2SeqTrainingArguments,
set_seed,
)
from argparse import Namespace
from llm.loader import ModelLoader
cfg = Namespace()
#model
cfg.checkout_mode = "lora" # lora or ptuning
cfg.model_name_or_path = consts.MODEL_PATH_ChatGLM2_32K # 32K 利于知识增强模式
# cfg.model_name_or_path = '/home/zfh/models/tunning/chatglm2-6b-lora-spdsvb'
cfg.ptuning_path = '/home/zfh/aird/model/ckpt/chatglm-6b-pt-spdsvb-INSv8-128-5e-3-3000/checkpoint-3000'
cfg.ckpt_path = '/home/zfh/aird/model/ckpt/chatglm-6b-lora-spdsvb-INSv10-1e-03-20'
cfg.pre_seq_len = 0
# cfg.pre_seq_len = 128
cfg.prefix_projection = False
cfg.quantization_bit = None
cfg.source_prefix = consts.INSTRUCTION_V1
## --------- load model --------------
if cfg.checkout_mode == "lora":
# lora 微调 checkpoint 及模型加载
loader = ModelLoader(cfg.model_name_or_path)
loader.load_lora(cfg.ckpt_path)
elif cfg.checkout_mode == "ptuning":
# ptuning v2 微调 checkpoint 及模型加载
loader = ModelLoader(cfg.model_name_or_path, cfg.pre_seq_len, False)
loader.load_prefix(cfg.ptuning_path)
model,tokenizer = loader.models()
if cfg.quantization_bit is not None:
model = loader.quantize(cfg.quantization_bit)
model = model.cuda().eval()
## --------- load model end --------------
"""Override Chatbot.postprocess"""
def postprocess(self, y):
if y is None:
return []
for i, (message, response) in enumerate(y):
y[i] = (
None if message is None else mdtex2html.convert((message)),
None if response is None else mdtex2html.convert(response),
)
return y
gr.Chatbot.postprocess = postprocess
def parse_text(text):
"""copy from https://github.com/GaiZhenbiao/ChuanhuChatGPT/"""
lines = text.split("\n")
lines = [line for line in lines if line != ""]
count = 0
for i, line in enumerate(lines):
if "```" in line:
count += 1
items = line.split('`')
if count % 2 == 1:
lines[i] = f'<pre><code class="language-{items[-1]}">'
else:
lines[i] = f'<br></code></pre>'
else:
if i > 0:
if count % 2 == 1:
line = line.replace("`", "\`")
line = line.replace("<", "&lt;")
line = line.replace(">", "&gt;")
line = line.replace(" ", "&nbsp;")
line = line.replace("*", "&ast;")
line = line.replace("_", "&lowbar;")
line = line.replace("-", "&#45;")
line = line.replace(".", "&#46;")
line = line.replace("!", "&#33;")
line = line.replace("(", "&#40;")
line = line.replace(")", "&#41;")
line = line.replace("$", "&#36;")
lines[i] = "<br>"+line
text = "".join(lines)
return text
# def predict(input, chatbot, max_length, top_p, temperature, enhance, instruct,history):
def predict(input, chatbot, max_length, enhance, instruct, history):
# chatbot.append((parse_text(input), ""))
if enhance:
print("知识增强")
history = []
from similarity import get_text_similarity
similarity = get_text_similarity(input)
print("similarity:",similarity)
if similarity is not None :
prompt=f"{similarity}\n请结合上述内容回答以下问题:\n{input}"
elif cfg.source_prefix is not None and instruct:
prompt = cfg.source_prefix + input
else:
prompt = f"对不起,我没有找到相关内容\n请结合上述内容回答以下问题:\n{input}"
input = f"[知识增强]{input}"
elif instruct:
print("专家指令")
if cfg.source_prefix is not None and instruct:
history = []
prompt = cfg.source_prefix + input
else:
prompt = input
input = f"[专家指令]{input}"
else:
print("普通问答")
prompt = input
input = f"[普通问答]{input}"
chatbot.append((input, ""))
logger.info(f"prompt: {prompt}")
# for response, history in model.stream_chat(tokenizer, prompt, history, max_length=max_length, top_p=top_p, # type: ignore
# temperature=temperature):
for response, history in model.stream_chat(tokenizer, prompt, history, max_length=max_length, # type: ignore
):
chatbot[-1] = (parse_text(input), parse_text(response))
yield chatbot, history
def reset_user_input():
return gr.update(value='')
def reset_state():
return [], []
with gr.Blocks() as demo:
gr.HTML("""<h1 align="center">ChatGLM - SPDSVB</h1>""")
chatbot = gr.Chatbot()
with gr.Row():
with gr.Column(scale=4):
with gr.Column(scale=12):
user_input = gr.Textbox(show_label=False, placeholder="Input...", lines=10)
with gr.Column(min_width=32, scale=1):
submitBtn = gr.Button("Submit", variant="primary")
with gr.Column(scale=1):
emptyBtn = gr.Button("Clear History")
max_length = gr.Slider(0, 32768, value=2048, step=100.0, label="Maximum length", interactive=True)
if cfg.model_name_or_path == consts.MODEL_PATH_ChatGLM2_32K:
max_length.value = 32768
# top_p = gr.Slider(0, 1, value=0.7, step=0.01, label="Top P", interactive=True)
# temperature = gr.Slider(0, 1, value=0.95, step=0.01, label="Temperature", interactive=True)
with gr.Row():
enhance = gr.Checkbox(label="知识增强", interactive=True)
instruct = gr.Checkbox(label="专家指令",value=True, interactive=True)
history = gr.State([])
# submitBtn.click(predict, [user_input, chatbot, max_length, top_p, temperature, enhance, instruct, history], [chatbot, history],
submitBtn.click(predict, [user_input, chatbot, max_length, enhance, instruct, history], [chatbot, history],
show_progress="minimal")
submitBtn.click(reset_user_input, [], [user_input])
emptyBtn.click(reset_state, outputs=[chatbot, history], show_progress='minimal')
def main():
demo.queue().launch(share=True,server_port=9999)
if __name__ == "__main__":
main()
import os, sys
sys.path.append("../..")
import json
import pandas
from llm.spark import SparkLLM
from llm.ernie import ErnieLLM, ModelType
from qa.generator import QAGenerator as QAGeneratorBase
from qa.generator import TrainData
from langchain.chat_models import ChatOpenAI
from langchain.base_language import BaseLanguageModel
class QAGenerator:
def __init__(self, llm: BaseLanguageModel):
"""
Initialize the QAGenerator.
:param llm: the language model used for generation
"""
self.generator = TrainData(llm=llm)
def generate_questions_and_answers(self, input_text, question_number=3):
"""
Generate questions and answers.
:param input_text: the input text
:param question_number: number of questions to generate (default 3)
:return: a list of (question, answer) pairs
"""
# questions = self.generator.generate_questions(input_text, question_number=question_number)
# result = []
# for q in questions:
# answer = self.generator.generate_answer(input_text, q)
# result.append((q, answer))
result = self.generator.generate(input_text, question_number=question_number)
return result
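# Usage sketch (hypothetical text; TrainData.generate is assumed to return
# (question, answer) pairs, which is how generate_questions_csv below unpacks them):
#
#   gen = QAGenerator(llm=llm)
#   for q, a in gen.generate_questions_and_answers("对账单列表和对账单明细可下载pdf文件和excel文档。", question_number=2):
#       print("Q:", q, "\nA:", a)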
def read_and_deduplicate_text(self, data_path, text_column):
"""
Read the file at the given path and return the de-duplicated list of texts.
:param data_path: path to the data file
:param text_column: name of the column that holds the text
:return: de-duplicated list of texts
"""
with open(data_path, "r", encoding="utf-8") as f:
text_list=[]
ext = data_path.split(".")[-1]
if ext == "json":
data = json.load(f)
for item in data:
text_list.append(item[text_column])
elif ext == "csv":
data = pandas.read_csv(f)
for item in data[text_column]:
text_list.append(item)
# de-duplicate
text_list = list(set(text_list))
print(f"共读取到 {len(text_list)} 条文本")
return text_list
def save_questions_and_answers(self, df, output_path):
"""
Append the questions and answers to a CSV file (written in append mode).
:param df: DataFrame of questions and answers
:param output_path: output file path
"""
if os.path.exists(output_path):
df.to_csv(output_path, mode="a", index=False, header=False)
else:
df.to_csv(output_path, index=False)
def generate_questions_csv(self, data_path, text_column, output_path, save_step=100, max_num=1000):
"""
Generate the CSV file of questions and answers.
:param data_path: path to the data file
:param text_column: name of the column that holds the text
:param output_path: output file path
:param save_step: flush accumulated results to disk every save_step texts
:param max_num: maximum number of texts to process
"""
# read the data file and de-duplicate
text_list = self.read_and_deduplicate_text(data_path, text_column)
text_list=text_list[:max_num]
# generate questions and answers
result = []
for i, input in enumerate(text_list):
if len(input) > 200:
num_questions = 5
else:
num_questions = 3
questions = self.generate_questions_and_answers(input, question_number=num_questions)
for q, a in questions:
result.append((input, q, a))
if (i+1) % save_step == 0:
questions_df = pandas.DataFrame(result, columns=["text", "question", "answer"])
self.save_questions_and_answers(questions_df, output_path)
result = []
# print a progress bar
progress = (i + 1) / len(text_list) * 100
bar_length = 50
filled_length = int(bar_length * progress // 100)
bar = "█" * filled_length + "-" * (bar_length - filled_length)
print(f"\r已处理 {i+1}/{len(text_list)} 条数据 |{bar}| {progress:.2f}%", end="")
# save the remaining results to the CSV file
questions_df = pandas.DataFrame(result, columns=["text", "question", "answer"])
self.save_questions_and_answers(questions_df, output_path)
# questions_df = pandas.DataFrame(result, columns=["text", "question", "answer"])
# questions_df.to_csv(output_path, index=False)
import os, sys
sys.path.append("../..")
from dotenv import load_dotenv,find_dotenv
load_dotenv(find_dotenv())
from llm.spark import SparkLLM
from llm.chatglm import ChatGLMSerLLM
from qa.generator import QAGenerator,TrainData
from langchain.chat_models import ChatOpenAI
# llm=SparkLLM()
# llm=ChatOpenAI(model_name="gpt-3.5-turbo")
# llm=ChatOpenAI(model_name="gpt-4")
# llm=ChatGLMSerLLM(url="http://localhost:8002")
from llm.ernie import ErnieLLM, ModelType
llm = ErnieLLM(model_name=ModelType.ERNIE_LITE)
generator=QAGenerator(llm=llm)
# input='''下载网银相关软件分为一下几个步骤
# 1.使用浏览器打开浦发硅谷银行首页:https://www.spd-svbank.com/cn/
# 2.点击右上角“网上银行”登录
# 3.在登录页“登录”按钮下方,点击“下载网银软件”,跳转至软件下载界面
# 4.选择需要下载的软件点击下载按钮即可
# #Windows:
# ##Firefox /IE 10/IE 11/Edge 浏览器:安装 Firefox 扩展,网银管家,USBKEY 驱动, 密码控件,签名和证书控件。
# ##Chrome 浏览器:安装 Chrome 扩展,网银管家,USBKEY 驱动,密码控件,签名和证书控件。
# #Mac
# ##Firefox/Chrome浏览器:安装网银管家,USBKEY 驱动,密码控件,签名和证书控件。
# (为保证您正常使用我行网上银行,使用网银 USBKEY 的客户请先安装我行网银管家安全组件。我行网银管家安全组件包括系统环境、网银控件、IE 设置等安全检测及网银USBKEY 管理工具,可一次性完成网银所需的所有控件及驱动程序的安装。)
# '''
input='''季度对账的页面描述,1、进入银企对账账户列表页面,点击银企对账,进入对账单列表页面,点击对账提交,进入对账单明细页面,选择对账状态,若为对账不相符,对账疑义为必输字段,若为对账相符,对账疑义可不输。
2、进入银企对账账户列表页面,点击对账详情,进入对账单列表页面,点击对账详情,进入对账单明细页面,可选择打印对账单明细数据。
3、对账单列表和对账单明细可下载pdf文件和excel文档。
'''
result=[]
# questions = generator.generate_questions(input,question_number=3)
# print("questions:",questions)
# print("---"*10)
# for q in questions:
# answer=generator.generate_answer(input,q)
# print("Q:",q,"\nA:",answer)
# print("----"*10)
train_data_gen=TrainData(llm=llm)
train_data=train_data_gen.generate(input,question_number=3)
print(train_data)
import sys
sys.path.append("../..")
from dotenv import load_dotenv
load_dotenv()
from qagenerator import QAGenerator
# from langchain.chat_models import ChatOpenAI
# llm=ChatOpenAI(model_name="gpt-3.5-turbo")
from llm.ernie import ErnieLLM, ModelType
llm = ErnieLLM(model_name=ModelType.ERNIE_LITE)
input_file="../../../data/knowledge_qa.csv"
output_file="../../../data/kownledge_questions_qa_t50.ernie.csv"
qa_generator_csv = QAGenerator(llm=llm)
qa_generator_csv.generate_questions_csv(input_file, "input", output_file, max_num=50, save_step=10)
import os, sys
sys.path.append("../..")
import logging
logging.basicConfig(filename='web.log', level=logging.INFO)
logger = logging.getLogger(__name__)
from dotenv import load_dotenv
load_dotenv()
from common import consts
import gradio as gr
import mdtex2html
from llm.spark import SparkLLM
from langchain.chat_models import ChatOpenAI
from qa.generator import QAGenerator
llm=SparkLLM()
generator=QAGenerator(llm=llm)
def parse_text(text):
"""copy from https://github.com/GaiZhenbiao/ChuanhuChatGPT/"""
lines = text.split("\n")
lines = [line for line in lines if line != ""]
count = 0
for i, line in enumerate(lines):
if "```" in line:
count += 1
items = line.split('`')
if count % 2 == 1:
lines[i] = f'<pre><code class="language-{items[-1]}">'
else:
lines[i] = f'<br></code></pre>'
else:
if i > 0:
if count % 2 == 1:
line = line.replace("`", "\\`")
line = line.replace("<", "&lt;")
line = line.replace(">", "&gt;")
line = line.replace(" ", "&nbsp;")
line = line.replace("*", "&ast;")
line = line.replace("_", "&lowbar;")
line = line.replace("-", "&#45;")
line = line.replace(".", "&#46;")
line = line.replace("!", "&#33;")
line = line.replace("(", "&#40;")
line = line.replace(")", "&#41;")
line = line.replace("$", "&#36;")
lines[i] = "<br>"+line
text = "".join(lines)
return text
# def predict(input, chatbot, max_length, top_p, temperature, enhance, instruct,history):
def predict(input, chatbot, max_length, history):
# chatbot.append((parse_text(input), ""))
result=[]
questions = generator.generate_questions(input,question_number=3)
# print("questions:",questions)
for q in questions:
answer=generator.generate_answer(input,q)
result.append((q,answer))
# print("Q:",q,"A:",answer)
response = "\n\n".join([f"{q}\n{a}" for q,a in result])
# print("response:",response)
chatbot.append((parse_text(response),""))
return chatbot, history
def reset_user_input():
return gr.update(value='')
def reset_state():
return [], []
def change_model(model_name):
global llm, generator
if model_name == "spark":
llm=SparkLLM()
elif model_name == "openai":
llm=ChatOpenAI(model_name="gpt-3.5-turbo")
generator=QAGenerator(llm=llm)
return model_name
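# Extending the model switcher (a sketch; the "ernie" branch is an assumption and the
# option would also need to be added to the gr.Radio choices below):
#
#   from llm.ernie import ErnieLLM, ModelType   # at the top of the file
#   ...
#   elif model_name == "ernie":                 # extra branch inside change_model
#       llm = ErnieLLM(model_name=ModelType.ERNIE_LITE)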
with gr.Blocks() as demo:
gr.HTML("""<h1 align="center">训练集提取</h1>""")
chatbot = gr.Chatbot()
with gr.Row():
with gr.Column(scale=4):
with gr.Column(scale=12):
user_input = gr.Textbox(show_label=False, placeholder="Input...", lines=10)
with gr.Column(min_width=32, scale=1):
submitBtn = gr.Button("Submit", variant="primary")
with gr.Column(scale=1):
model_name = gr.Radio(choices=["spark", "openai"], label="Model", value="spark", inline=True,interactive=True)
max_length = gr.Slider(0, 32768, value=2048, step=100.0, label="Maximum length", interactive=True)
emptyBtn = gr.Button("Clear History")
# top_p = gr.Slider(0, 1, value=0.7, step=0.01, label="Top P", interactive=True)
# temperature = gr.Slider(0, 1, value=0.95, step=0.01, label="Temperature", interactive=True)
model_name.change(change_model,[model_name],[model_name])
history = gr.State([])
# submitBtn.click(predict, [user_input, chatbot, max_length, top_p, temperature, enhance, instruct, history], [chatbot, history],
submitBtn.click(predict, [user_input, chatbot, max_length, history], [chatbot, history],
show_progress="minimal")
submitBtn.click(reset_user_input, [], [user_input])
emptyBtn.click(reset_state, outputs=[chatbot, history], show_progress='minimal')
def main():
demo.queue().launch(share=True,server_port=9999)
if __name__ == "__main__":
main()
import json
import matplotlib.pyplot as plt
import numpy as np
# Load each eval_results.json file, read the BLEU / ROUGE metrics in it,
# and draw grouped bar charts comparing the checkpoints.
def load_json(file_path):
with open(file_path, 'r') as f:
data = json.load(f)
return data
dataset_chatglm = [
{"chatglm1-6b-spdsvb":"/home/zfh/ChatGLM/ChatGLM-6B/ptuning/output/chatglm-6b-spdsvb/eval_results.json"},
# {"chatglm-6b-pt-spdsvb-128-1e-3-3000-v1":'/home/zfh/ChatGLM/ChatGLM-6B/ptuning/output/adgen-chatglm-6b-pt-spdsvb-128-1e-3-3000/eval_results.json'},
# {"chatglm-6b-pt-spdsvb-128-1e-3-3000-v2":'/home/zfh/ChatGLM/ChatGLM-6B/ptuning/output/adgen-chatglm-6b-pt-spdsvb-128-1e-3-3000-v2/eval_results.json'},
{"chatglm-6b-pt-spdsvb-128-5e-3-3000-base":"/home/zfh/ChatGLM/ChatGLM-6B/ptuning/output/adgen-chatglm-6b-pt-128-5e-3-3000-spdsvb-base/eval_results.json"},
{"chatglm-6b-lora-spdsvb-base":"/home/zfh/ChatGLM/ChatGLM-6B/ptuning/output/chatglm-6b-lora-spdsvb-base/eval_results.json"},
{"chatglm-6b-lora-spdsvb-base-5e-3-50":"/home/zfh/ChatGLM/ChatGLM-6B/ptuning/output/chatglm-6b-lora-spdsvb-base-5e-3-50/eval_results.json"},
{"chatglm-6b-lora-spdsvb-INSv4-1e-03-50":"/home/zfh/ChatGLM/ChatGLM-6B/ptuning/output/chatglm-6b-lora-spdsvb-INSv4-1e-03-50/eval_results.json"},
{"chatglm-qlora-t1":"/home/zfh/ChatGLM/ChatGLM-6B/ptuning/output/chatglm-qlora-t1/eval_results.json"},
{"chatglm-6b-qlora-spdsvb-v4_t32":"/home/zfh/ChatGLM/ChatGLM-6B/ptuning/output/chatglm-6b-qlora-spdsvb-v4_t32/eval_results.json"},
]
labels = [
"base",
# "pt-v1",
# "pt-v2",
"pt-base",
"alora-base",
"alora-5e3-base",
"alora-5e3-v4",
"qlora-base",
"qlora-v4",
]
dataset_chatglm2 = [
{"chatglm2-6b-spdsvb":"/home/zfh/ChatGLM/ChatGLM2-6B/ptuning/output/chatglm2-6b/spdsvb-base/eval_results.json"},
{"chatglm2-6b-pt-spdsvb-base":"/home/zfh/ChatGLM/ChatGLM2-6B/ptuning/output/adgen-chatglm2-6b-pt-128-5e-4-spdsvb-base/eval_results.json"},
{"chatglm2-6b-lora-spdsvb":"/home/zfh/models/tunning/chatglm2-6b-lora-spdsvb/eval_results.json"},
{"chatglm2-6b-lora-spdsvb-v3":"/home/zfh/models/tunning/chatglm2-6b-spdb-v3/eval_results.json"},
]
labels2 = [
"2-base",
"2-pt-base",
"2-lora-v1",
"2-lora-v3",
]
def bar_eval(x,width,data,label):
y = data
plt.bar(x, y, width=width, label=label)
plt.plot(x, y, '-o')
plt.legend()
def create_plot(dataset,labels,name):
plt.figure(figsize=(10, 4))
bleu_4,rouge_1,rouge_2,rouge_l = [],[],[],[]
for item in dataset:
for key,value in item.items():
eval_data = load_json(value)
bleu_4.append(eval_data['eval_bleu-4'])
rouge_1.append(eval_data['eval_rouge-1'])
rouge_2.append(eval_data['eval_rouge-2'])
rouge_l.append(eval_data['eval_rouge-l'])
x=np.arange(len(labels))
plt.xticks(x,labels)
width = 0.1
bar_eval(x-1.5*width,width,bleu_4,'bleu-4')
bar_eval(x-0.5*width,width,rouge_1,'rouge-1')
bar_eval(x+0.5*width,width,rouge_2,'rouge-2')
bar_eval(x+1.5*width,width,rouge_l,'rouge-l')
plt.title(f"{name} spdsvb eval")
plt.savefig(f'../../images/eval/{name}.png')
create_plot(dataset_chatglm,labels,'chatglm')
create_plot(dataset_chatglm2,labels2,'chatglm2')
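# The eval_results.json files referenced above are assumed to be flat dicts with the
# metric keys read in create_plot(); the numbers here are made-up placeholders:
# {
#     "eval_bleu-4": 25.3,
#     "eval_rouge-1": 48.1,
#     "eval_rouge-2": 27.6,
#     "eval_rouge-l": 40.2
# }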
import json
import matplotlib.pyplot as plt
# Load each trainer_state.json file, read its log_history list,
# and plot the train / eval loss curves from it.
def load_json(file_path):
with open(file_path, 'r') as f:
data = json.load(f)
return data['log_history']
def plot_loss(log_history,label):
x,y=[],[]
ex,ey=[],[]
# for i in range(len(log_history)):
for item in log_history:
if 'loss' in item:
x.append(item["step"])
y.append(item["loss"])
elif 'eval_loss' in item:
ex.append(item["step"])
ey.append(item["eval_loss"])
else:
# print('Error: loss key not found in item:', item)
continue
plt.plot(x, y,'-o', label=label+"train_loss")
if len(ey)>0:
plt.plot(ex, ey,'-o', label=label+"eval_loss")
# plt.scatter(x, y, label='train_loss')
plt.legend()
# plt.show(block=True)
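# plot_loss() expects log_history entries shaped like the Hugging Face Trainer log,
# i.e. dicts carrying "step" plus either "loss" or "eval_loss" (the values below are
# illustrative placeholders):
# [
#     {"step": 100, "loss": 3.12},
#     {"step": 500, "eval_loss": 2.87},
# ]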
def create_plot(dataset,name):
plt.figure(figsize=(10, 4))
for item in dataset:
for key,value in item.items():
log_history = load_json(value)
plot_loss(log_history,key)
plt.savefig(f'../../images/loss/{name}.png')
def create_plot_one(state_file):
plt.figure(figsize=(10, 4))
log_history = load_json(state_file)
out_dir = '/'.join(state_file.split('/')[:-1])
plot_loss(log_history,"")
plt.savefig(f'{out_dir}/loss.png')
dataset_chatglm = [
# {"adgen-chatglm-6b-pt-comb-128-2e-2":'/home/zfh/ChatGLM/ChatGLM-6B/ptuning/output/adgen-chatglm-6b-pt-comb-128-2e-2/trainer_state.json'},
# {"adgen-chatglm-6b-pt-comb-128-2e-2-3000":'/home/zfh/ChatGLM/ChatGLM-6B/ptuning/output/adgen-chatglm-6b-pt-comb-128-2e-2-3000/trainer_state.json'},
# {"adgen-chatglm-6b-pt-comb-128-5e-3-1000":'/home/zfh/ChatGLM/ChatGLM-6B/ptuning/output/adgen-chatglm-6b-pt-comb-128-5e-3-1000/trainer_state.json'},
# {"adgen-chatglm-6b-pt-comb-128-1e-3-3000":'/home/zfh/ChatGLM/ChatGLM-6B/ptuning/output/adgen-chatglm-6b-pt-comb-128-1e-3-3000/trainer_state.json'},
{"adgen-chatglm-6b-pt-spdsvb-128-1e-2-3000":'/home/zfh/ChatGLM/ChatGLM-6B/ptuning/output/adgen-chatglm-6b-pt-spdsvb-128-1e-3-3000/trainer_state.json'},
{"adgen-chatglm-6b-pt-spdsvb-128-1e-3-3000-v2":'/home/zfh/ChatGLM/ChatGLM-6B/ptuning/output/adgen-chatglm-6b-pt-spdsvb-128-1e-3-3000-v2/trainer_state.json'},
{"adgen-chatglm-6b-pt-spdsvb-128-1e-3-3000-v3":'/home/zfh/ChatGLM/ChatGLM-6B/ptuning/output/adgen-chatglm-6b-pt-spdsvb-128-1e-3-3000-v3/trainer_state.json'},
{"adgen-chatglm-6b-pt-spdsvb-128-5e-3-3000-base":'/home/zfh/ChatGLM/ChatGLM-6B/ptuning/output/adgen-chatglm-6b-pt-128-5e-3-3000-spdsvb-base/trainer_state.json'},
# {"adgen-chatglm2-6b-pt-128-5e-4-spdsvb-v3":"/home/zfh/ChatGLM/ChatGLM2-6B/ptuning/output/adgen-chatglm2-6b-pt-128-5e-4-spdsvb-v3/trainer_state.json"},
# {"adgen-chatglm2-6b-pt-128-5e-4-spdsvb-base":"/home/zfh/ChatGLM/ChatGLM2-6B/ptuning/output/adgen-chatglm2-6b-pt-128-5e-4-spdsvb-base/trainer_state.json"},
{"chatglm-6b-pt-spdsvb-INSv8-128-5e-3-3000":"/home/zfh/aird/model/ckpt/chatglm-6b-pt-spdsvb-INSv8-128-5e-3-3000/trainer_state.json"},
{"chatglm-6b-pt-spdsvb-INSv9-128-5e-3-3000":"/home/zfh/aird/model/ckpt/chatglm-6b-pt-spdsvb-INSv9-128-5e-3-3000/trainer_state.json"},
{"chatglm-6b-pt-spdsvb-INSv10-128-5e-3-3000":"/home/zfh/aird/model/ckpt/chatglm-6b-pt-spdsvb-INSv10-128-5e-3-3000/trainer_state.json"},
]
dataset_v11 = [
{"chatglm2-6b-pt-spdsvb-INSv11-128-1e-3-3000":"/home/zfh/aird/model/ckpt/chatglm2-6b-pt-spdsvb-INSv11-128-1e-3-3000/trainer_state.json"},
{"chatglm2-6b-pt-spdsvb-INSv11-128-5e-3-3010":"/home/zfh/aird/model/ckpt/chatglm2-6b-pt-spdsvb-INSv11-128-5e-3-3010/trainer_state.json"},
{"chatglm2-6b-qlora-INSv11-rank16-1e-3-30":"/home/zfh/aird/model/ckpt/chatglm2-6b-qlora-INSv11-rank16-1e-3-30/checkpoint-2000/trainer_state.json"},
{"chatglm2-6b-pt-spdsvb-INSv11-128-3e-3-1000+s1000":"/home/zfh/aird/model/ckpt/chatglm2-6b-pt-spdsvb-INSv11-128-3e-3-1000+s1000/trainer_state.json"},
{"chatglm2-6b-32k-qlora-INSv11-rank16-5e-4-30":"/home/zfh/aird/model/ckpt/chatglm2-6b-32k-qlora-INSv11-rank16-5e-4-30/checkpoint-2800/trainer_state.json"},
{"chatglm2-6b-32k-pt-spdsvb-INSv11-128-3e-3-3000":"/home/zfh/aird/model/ckpt/chatglm2-6b-32k-pt-spdsvb-INSv11-128-3e-3-3000/checkpoint-3000/trainer_state.json"}
]
dataset_lora = [
# {"chatglm-6b-lora-spdsvb-base":'/home/zfh/aird/model/ckpt/chatglm-6b-lora-spdsvb-base/train_history.csv'},
{"chatglm2-6b-qlora-INSv11_rank16-1e-3-30":"/home/zfh/aird/model/ckpt/chatglm2-6b-qlora-INSv11_rank16-1e-3-30/checkpoint-2000/trainer_state.json"},
]
create_plot(dataset_chatglm,'spdsvb')
# create_plot(dataset_chatglm2,'chatglm2')
for item in dataset_chatglm:
for key,value in item.items():
create_plot_one(value)
create_plot(dataset_v11,'spdsvb_v11')
for item in dataset_v11:
for key,value in item.items():
create_plot_one(value)
import csv
import json
import sys
sys.path.append("..")
from common import consts
def data_format(csv_file, jsonl_file):
data = []
with open(csv_file, 'r',encoding='utf-8') as f:
reader = csv.DictReader(f)
for row in reader:
prompt = row['prompt']
response = row['response']
data.append({'prompt': prompt, 'response': response})
with open(jsonl_file, 'w',encoding='utf-8') as f:
for d in data:
f.write(json.dumps(d, ensure_ascii=False) + '\n')
csv_file = '../../data/train_comb.csv'
jsonl_file = '../../data/sheet_train_comb.json'
data_format(csv_file, jsonl_file)
csv_file = '../../data/train_comb_eval.csv'
jsonl_file = '../../data/sheet_train_comb_eval.json'
data_format(csv_file, jsonl_file)
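# Each line of the resulting JSONL file is one flat record, e.g. (content made up
# for illustration): {"prompt": "银企对账在哪里提交?", "response": "进入银企对账账户列表页面,点击银企对账……"}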
import csv
import json
import sys
sys.path.append("..")
from common import consts
csv_file = '../../data/train_spdsvb_v11.csv'
jsonl_file = '../../data/train_spdsvb_v11.jsonl'
data = []
with open(csv_file, 'r',encoding='utf-8') as f:
reader = csv.DictReader(f)
for row in reader:
prompt = consts.INSTRUCTION_V1+row['prompt']
response = [[row['response']]]
data.append([{'prompt': prompt, 'response': response}])
with open(jsonl_file, 'w',encoding='utf-8') as f:
for d in data:
f.write(json.dumps(d, ensure_ascii=False) + '\n')
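# Note the different record shape in this script: each output line is a single-element
# list and the response is doubly nested, e.g. (content made up for illustration):
# [{"prompt": "你是浦发硅谷银行网银系统的专家,...银企对账在哪里提交?", "response": [["进入银企对账账户列表页面……"]]}]
# presumably to match a training format that expects a list of candidate answers per prompt.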
#!/bin/bash
DATA_TRAIN_FILE='../../../data/train_spdsvb_v10.csv'
DATA_VAL_FILE='../../../data/val_spdsvb_v4.csv'
PROMPT_PREFIX="你是浦发硅谷银行网银系统的专家,请帮助解答用户在使用过程中遇到的问题。"
MODEL_PATH_ChatGLM="/home/zfh/models/chatglm-6b"
MODEL_PATH_ChatGLM2="/home/zfh/models/chatglm2-6b"
MODEL_PATH_ChatGLM2_32K="/home/zfh/models/chatglm2-6b-32k"
MODEL_NAME_ChatGLM="THUDM/chatglm-6b"
MODEL_NAME_ChatGLM2="THUDM/chatglm2-6b"
INSTRUCTION_V1="你是浦发硅谷银行网银系统的专家,请帮助解答用户在使用过程中遇到的问题。\n"