import copy
import os
import re
from typing import List

from bs4 import BeautifulSoup
from langchain_community.document_loaders import (
    CSVLoader,
    PDFMinerPDFasHTMLLoader,
    TextLoader,
    UnstructuredFileLoader,
    UnstructuredPDFLoader,
    UnstructuredWordDocumentLoader,
)
from langchain_core.documents import Document

from src.loader.callback import BaseCallback

from .chinese_text_splitter import ChineseTextSplitter
from .config import SENTENCE_SIZE, ZH_TITLE_ENHANCE
from .zh_title_enhance import zh_title_enhance


def load(filepath, mode: str = None, sentence_size: int = 0, metadata=None, callbacks=None, **kwargs):
    r"""
        加载文档,参数说明
        mode:文档切割方式,"single", "elements", "paged"
        sentence_size:对于较大的document再次切割成多个
        kwargs
    """
    if filepath.lower().endswith(".md"):
        loader = UnstructuredFileLoader(filepath, mode=mode or "elements", **kwargs)
    elif filepath.lower().endswith(".txt"):
        loader = TextLoader(filepath, autodetect_encoding=True, **kwargs)
    elif filepath.lower().endswith(".csv"):
        loader = CSVLoader(filepath, **kwargs)
    elif filepath.lower().endswith(".pdf"):
        # loader = UnstructuredPDFLoader(filepath, mode=mode or "elements",**kwargs)
        # use the custom PDF loader instead
        return __pdf_loader(filepath, sentence_size=sentence_size, metadata=metadata, callbacks=callbacks)
    elif filepath.lower().endswith(".docx") or filepath.lower().endswith(".doc"):
        loader = UnstructuredWordDocumentLoader(filepath, mode=mode or "elements", **kwargs)
    else:
        loader = UnstructuredFileLoader(filepath, mode=mode or "elements", **kwargs)
    if sentence_size > 0:
        try:
            return split(loader.load(), sentence_size)
        except Exception as e:
            print(f"failed to load {filepath}: {e}")
            return []
    return loader.load()
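
# Usage sketch for `load` (illustrative only; "manual.pdf" is a hypothetical
# local file). PDF files take the custom __pdf_loader path below; other types
# go through the matching langchain loader:
#
#     docs = load("manual.pdf", sentence_size=512)
#     for d in docs[:3]:
#         print(d.metadata.get("page_number"), len(d.page_content))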


def loads_path(path: str, **kwargs):
    return loads(get_files_in_directory(path), **kwargs)


def loads(filepaths, **kwargs):
    default_kwargs = {"mode": "paged"}
    default_kwargs.update(**kwargs)
    documents = [load(filepath=file, **default_kwargs) for file in filepaths]
    return [item for sublist in documents for item in sublist]


def append(documents=None, sentence_size: int = SENTENCE_SIZE):  # keeps document structure; mind the next_hash handling
    if not documents:
        return []
    effect_documents = []
    last_doc = documents[0]
    for doc in documents[1:]:
        last_hash = last_doc.metadata.get("next_hash", "")
        doc_hash = doc.metadata.get("next_hash", "")
        if len(last_doc.page_content) + len(doc.page_content) <= sentence_size and last_hash == doc_hash:
            last_doc.page_content += doc.page_content
        else:
            effect_documents.append(last_doc)
            last_doc = doc
    effect_documents.append(last_doc)
    return effect_documents
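
# A minimal sketch of how `append` merges (hypothetical documents): consecutive
# chunks whose combined length stays within sentence_size and whose "next_hash"
# metadata matches are collapsed into one:
#
#     a = Document(page_content="第一段。", metadata={"next_hash": "h1"})
#     b = Document(page_content="第二段。", metadata={"next_hash": "h1"})
#     assert len(append([a, b], sentence_size=100)) == 1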


def split(documents=None, sentence_size: int = SENTENCE_SIZE):  # keeps document structure; mind the next_doc handling
    if documents is None:
        documents = []
    effect_documents = []
    for doc in documents:
        if len(doc.page_content) > sentence_size:
            words_list = re.split(r'·-·', doc.page_content.replace("。", "。·-·").replace("\n", "\n·-·"))  # insert a marker after sentence boundaries, then split on it
            document = Document(page_content="", metadata=copy.deepcopy(doc.metadata))
            first = True
            for word in words_list:
                if len(document.page_content) + len(word) < sentence_size:
                    document.page_content += word
                else:
                    if len(document.page_content.replace(" ", "").replace("\n", "")) > 0:
                        if first:
                            first = False
                        else:
                            # link: store this chunk's text on the previous chunk
                            effect_documents[-1].metadata["next_doc"] = document.page_content
                        effect_documents.append(document)
                    document = Document(page_content=word, metadata=copy.deepcopy(doc.metadata))
            if len(document.page_content.replace(" ", "").replace("\n", "")) > 0:
                if not first:
                    effect_documents[-1].metadata["next_doc"] = document.page_content
                effect_documents.append(document)
        else:
            effect_documents.append(doc)
    return effect_documents
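
# A sketch of `split`'s chunk linking (hypothetical input): a document longer
# than sentence_size is cut at "。" and "\n" boundaries, and each chunk stores
# the following chunk's text under metadata["next_doc"]:
#
#     long_doc = Document(page_content="一句。" * 200, metadata={"source": "demo"})
#     chunks = split([long_doc], sentence_size=100)
#     print(len(chunks), chunks[0].metadata["next_doc"][:10])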


def load_file(filepath, sentence_size=SENTENCE_SIZE, using_zh_title_enhance=ZH_TITLE_ENHANCE, mode: str = None,
              **kwargs):
    print("load_file", filepath)
    if filepath.lower().endswith(".md"):
        loader = UnstructuredFileLoader(filepath, mode=mode or "elements", **kwargs)
        docs = loader.load()
    elif filepath.lower().endswith(".txt"):
        loader = TextLoader(filepath, autodetect_encoding=True, **kwargs)
        textsplitter = ChineseTextSplitter(pdf=False, sentence_size=sentence_size)
        docs = loader.load_and_split(textsplitter)
    elif filepath.lower().endswith(".csv"):
        loader = CSVLoader(filepath, **kwargs)
        docs = loader.load()
    elif filepath.lower().endswith(".pdf"):
        loader = UnstructuredPDFLoader(filepath, mode=mode or "elements", **kwargs)
        textsplitter = ChineseTextSplitter(pdf=True, sentence_size=sentence_size)
        docs = loader.load_and_split(textsplitter)
    elif filepath.lower().endswith(".docx"):
        loader = UnstructuredWordDocumentLoader(filepath, mode=mode or "elements", **kwargs)
        textsplitter = ChineseTextSplitter(pdf=False, sentence_size=sentence_size)
        docs = loader.load_and_split(textsplitter)
    else:
        loader = UnstructuredFileLoader(filepath, mode=mode or "elements", **kwargs)
        textsplitter = ChineseTextSplitter(pdf=False, sentence_size=sentence_size)
        docs = loader.load_and_split(text_splitter=textsplitter)
    if using_zh_title_enhance:
        docs = zh_title_enhance(docs)
    write_check_file(filepath, docs)
    return docs
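
# Usage sketch for `load_file` (the .txt branch; "notes.txt" is hypothetical).
# Text files are re-split with ChineseTextSplitter, and a check file is written
# next to the source under tmp_files/load_file.txt:
#
#     docs = load_file("notes.txt", sentence_size=SENTENCE_SIZE)
#     print(len(docs))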


def write_check_file(filepath, docs):
    folder_path = os.path.join(os.path.dirname(filepath), "tmp_files")
    if not os.path.exists(folder_path):
        os.makedirs(folder_path)
    fp = os.path.join(folder_path, 'load_file.txt')
    with open(fp, 'a+', encoding='utf-8') as fout:
        fout.write("filepath=%s,len=%s" % (filepath, len(docs)))
        fout.write('\n')
        for i in docs:
            fout.write(str(i))
            fout.write('\n')


def get_files_in_directory(directory):
    file_paths = []
    for root, dirs, files in os.walk(directory):
        for file in files:
            file_path = os.path.join(root, file)
            file_paths.append(file_path)
    return file_paths


# Custom PDF loading below
def __checkV(strings: str):
    # Filter out text that looks like a table or scattered fragments:
    # more than 3 lines averaging fewer than 15 effective characters each.
    lines = len(strings.splitlines())
    if lines > 3 and len(strings.replace(" ", "")) / lines < 15:
        return False
    return True


def __isTitle(strings: str):
    # A title is a single non-empty line that ends with a newline
    return len(strings.splitlines()) == 1 and len(strings) > 0 and strings.endswith("\n")


def __appendPara(strings: str):
    # Join wrapped lines within a paragraph: keep a line break only when it
    # follows sentence-ending punctuation (".", "。", "?", "?"); drop the rest.
    return re.sub(r"(?<![.。??])\n", "", strings)


def __check_fs_ff(line_ff_fs_s, fs, ff):  # prefer the previous line's font family/size if present on this line; otherwise default to those of the longest text run
    re_fs = line_ff_fs_s[-1][0][-1]
    re_ff = line_ff_fs_s[-1][1][-1] if line_ff_fs_s[-1][1] else None
    max_len = 0
    for ff_fs in line_ff_fs_s:  # find the font family/size of the longest text run
        if not ff_fs[0]:  # skip spans with no font-size information
            continue
        c_max = max(map(int, ff_fs[0]))
        if max_len < ff_fs[2] or (max_len == ff_fs[2] and c_max > int(re_fs)):
            max_len = ff_fs[2]
            re_fs = c_max
            re_ff = ff_fs[1][-1] if ff_fs[1] else None
    if fs:
        for ff_fs in line_ff_fs_s:
            if str(fs) in ff_fs[0] and ff in ff_fs[1]:
                re_fs = fs
                re_ff = ff
                break
    return int(re_fs), re_ff
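
# Shape of line_ff_fs_s (built in __pdf_loader below): one tuple per span on a
# line -- (font-size matches, font-family matches, effective char count). A
# hypothetical two-span line where the 12px run is the longest:
#
#     line = [(['12'], ['SimSun'], 20), (['16'], ['SimHei'], 2)]
#     __check_fs_ff(line, None, None)  # -> (12, 'SimSun')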


def append_document(snippets: List[Document], title: str, content: str, callbacks, font_size, page_num, metadata,
                    need_append: bool = False):
    if callbacks:
        for cb in callbacks:
            if isinstance(cb, BaseCallback):
                if cb.filter(title, content):
                    return
    if need_append and len(snippets) > 0:
        # merge into the previous snippet instead of starting a new one
        ps = snippets.pop()
        snippets.append(Document(page_content=ps.page_content + title, metadata=ps.metadata))
    else:
        doc_metadata = {"font-size": font_size, "page_number": page_num}
        doc_metadata.update(metadata)
        snippets.append(Document(page_content=title + content, metadata=doc_metadata))
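
# A hedged sketch of a filter callback; only the filter(title, content) -> bool
# usage visible in append_document is assumed about BaseCallback. Returning
# True drops the snippet:
#
#     class SkipTocCallback(BaseCallback):
#         def filter(self, title: str, content: str) -> bool:
#             return "目录" in title  # drop table-of-contents sections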


'''
    Extract a PDF document, splitting it into title + content chunks; each
    chunk's page number is the page on which its title appears.
    Chunks longer than sentence_size are split again, and every sub-chunk
    keeps the page number of its parent chunk.
'''


def __pdf_loader(filepath: str, sentence_size: int = 0, metadata=None, callbacks=None):
    if not filepath.lower().endswith(".pdf"):
        raise ValueError("file is not a pdf document")
    loader = PDFMinerPDFasHTMLLoader(filepath)
    documents = loader.load()
    soup = BeautifulSoup(documents[0].page_content, 'html.parser')
    content = soup.find_all('div')
    cur_fs = None  # font size of the current text
    last_fs = None  # font size of the previous text
    cur_ff = None  # font family of the current text
    cur_text = ''
    fs_increasing = False  # the font grows on the next line: treat it as a title and split there
    last_text = ''
    last_page_num = 1  # previous page number; page_split decides which page the current text is attributed to
    page_num = 1  # initial page number
    page_change = False  # a page boundary was crossed
    page_split = False  # whether text has been split on the current page
    last_is_title = False  # whether the previous text was a title
    snippets: List[Document] = []

    filename = os.path.basename(filepath)
    if metadata:
        metadata.update({'source': filepath, 'filename': filename, 'filetype': 'application/pdf'})
    else:
        metadata = {'source': filepath, 'filename': filename, 'filetype': 'application/pdf'}
    for c in content:
        divs = c.get('style')
        if re.match(r"^(Page|page)", c.text):  # detect the page-number marker of the current page
            match = re.match(r"^(page|Page)\s+(\d+)", c.text)
            if match:
                if page_split:  # text was split on this page, so advance; otherwise keep the page the current text started on
                    last_page_num = page_num
                page_num = int(match.group(2))
                if len(last_text) + len(cur_text) == 0:  # page turned with no pending text: catch the previous page number up
                    last_page_num = page_num
                page_change = True
                page_split = False
            continue
        if re.findall('writing-mode:(.*?);', divs or '') == ['False'] or re.match(r'^[0-9\s\n]+$', c.text) or re.match(
                r"^第\s+\d+\s+页$", c.text):  # invisible text, pure digits, or a Chinese "page N" footer
            continue
        if len(c.text.replace("\n", "").replace(" ", "")) <= 1:  # drop lines with at most one effective character
            continue
        sps = c.find_all('span')
        if not sps:
            continue
        line_ff_fs_s = []  # spans with more than one effective character
        line_ff_fs_s2 = []  # spans with exactly one effective character
        for sp in sps:  # a single line can contain spans with different styles
            sp_len = len(sp.text.replace("\n", "").replace(" ", ""))
            if sp_len > 0:
                st = sp.get('style')
                if st:
                    ff_fs = (re.findall(r'font-size:(\d+)px', st), re.findall(r'font-family:(.*?);', st),
                             sp_len)
                    if sp_len == 1:  # keep single-character spans separately
                        line_ff_fs_s2.append(ff_fs)
                    else:
                        line_ff_fs_s.append(ff_fs)

        if len(line_ff_fs_s) == 0:  # if empty, fall back to the single-character spans
            if len(line_ff_fs_s2) > 0:
                line_ff_fs_s = line_ff_fs_s2
            else:
                if len(c.text) > 0:
                    page_change = False
                continue
        fs, ff = __check_fs_ff(line_ff_fs_s, cur_fs, cur_ff)
        if not cur_ff:
            cur_ff = ff
        if not cur_fs:
            cur_fs = fs

        if abs(fs - cur_fs) <= 1 and ff == cur_ff:  # neither font size nor family changed
            cur_text += c.text
            cur_fs = fs
            page_change = False
            if len(cur_text.splitlines()) > 3:  # after several consecutive lines, fs_increasing no longer applies
                fs_increasing = False
        else:
            if page_change and cur_fs > fs + 1:  # page turned and the font shrank: most likely a page header, skip c.text (may occasionally drop a real line)
                page_change = False
                continue
            if last_is_title:  # the previous text was a title
                if __isTitle(cur_text) or fs_increasing:  # consecutive titles, or the font-growing flag is set
                    last_text = last_text + cur_text
                    last_is_title = True
                    fs_increasing = False
                else:
                    append_document(snippets, last_text, __appendPara(cur_text), callbacks, cur_fs,
                                    page_num if page_split else last_page_num, metadata)
                    page_split = True
                    last_text = ''
                    last_is_title = False
                    fs_increasing = int(fs) > int(cur_fs)  # the font grew
            else:
                if len(last_text) > 0 and __checkV(last_text):  # filter table-like text
                    # merge text that crosses a page boundary or spans only a few lines
                    append_document(snippets, __appendPara(last_text), "", callbacks, last_fs,
                                    page_num if page_split else last_page_num, metadata,
                                    need_append=len(last_text.splitlines()) <= 2 or page_change)
                    page_split = True
                last_text = cur_text
                last_is_title = __isTitle(last_text) or fs_increasing
                fs_increasing = int(fs) > int(cur_fs)
            if page_split:
                last_page_num = page_num
            last_fs = cur_fs
            cur_fs = fs
            cur_ff = ff
            cur_text = c.text
            page_change = False
    append_document(snippets, last_text, __appendPara(cur_text), callbacks, cur_fs,
                    page_num if page_split else last_page_num, metadata)  # flush the final pending text
    if sentence_size > 0:
        return split(snippets, sentence_size)
    return snippets
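
# End-to-end sketch (hypothetical paths): walk a directory, load every file in
# "paged" mode, and re-split anything longer than SENTENCE_SIZE; the PDF call
# reuses the hypothetical SkipTocCallback sketched above:
#
#     all_docs = loads_path("./data", sentence_size=SENTENCE_SIZE)
#     pdf_docs = load("./data/report.pdf", sentence_size=SENTENCE_SIZE,
#                     metadata={"kb": "demo"}, callbacks=[SkipTocCallback()])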