from .c_db import UPostgresDB

# DDL script that drops and recreates the "chat" table and its "chat_seq"
# id sequence. Running it deletes any existing rows in "chat".
#
# Columns (translated from the SQL COMMENT statements below):
#   chat_id     - chat (conversation) id; defaults to nextval('chat_seq')
#                 even though the column is varchar
#   user_id     - id of the user who created the chat
#   info        - short description of the chat
#   create_time - creation time, defaults to the current timestamp
#   deleted     - soft-delete flag: 0 = no, 1 = yes
#
# The table (whose chat_id default references chat_seq) is dropped before the
# sequence is dropped and recreated; the default is re-attached at the end.
# The sequence range is capped at 2147483647 (the int4 maximum).
TABLE_CHAT = """
DROP TABLE IF EXISTS "chat";
CREATE TABLE chat (
    chat_id varchar(1000) PRIMARY KEY,
    user_id varchar(1000),
    info text,
    create_time timestamp(6) DEFAULT current_timestamp,
    deleted int2
);
COMMENT ON COLUMN "chat"."chat_id" IS '会话id';
COMMENT ON COLUMN "chat"."user_id" IS '会话创建用户id';
COMMENT ON COLUMN "chat"."info" IS '会话简介';
COMMENT ON COLUMN "chat"."create_time" IS '会话创建时间,默认为当前时间';
COMMENT ON COLUMN "chat"."deleted" IS '是否删除:0=否,1=是';
COMMENT ON TABLE "chat" IS '会话信息表';
DROP SEQUENCE IF EXISTS "chat_seq";
CREATE SEQUENCE "chat_seq"
INCREMENT 1
MINVALUE 1
MAXVALUE 2147483647
START 1
CACHE 1;
ALTER TABLE chat ALTER COLUMN chat_id SET DEFAULT nextval('chat_seq'::regclass);
"""

# DDL script that drops and recreates the "turn_qa" table (one row per
# question/answer turn of a chat) and its "turn_qa_seq" id sequence.
# Running it deletes any existing rows in "turn_qa".
#
# Columns (translated from the SQL COMMENT statements below):
#   turn_id      - turn id; defaults to nextval('turn_qa_seq')
#   chat_id      - id of the chat this turn belongs to
#   question     - the question asked in this turn
#   answer       - the answer given in this turn
#   similar_docs - hash index of the similar documents for this turn
#   create_time  - creation time, defaults to the current timestamp
#   turn_number  - turn number within the chat
#   is_last      - described as "is this the last round": 0 = no, 1 = yes
#
# NOTE(review): the queries in this file select several is_last = 1 rows per
# chat ordered by turn_number, and clear is_last for a single turn_number, so
# the flag may really mean "current version of this turn" -- confirm.
# No foreign key to "chat" is declared.
TABLE_TURN_QA = """
DROP TABLE IF EXISTS "turn_qa";
CREATE TABLE turn_qa (
    turn_id varchar(1000) PRIMARY KEY,
    chat_id varchar(1000),
    question text,
    answer text,
    similar_docs text,
    create_time timestamp(6) DEFAULT current_timestamp,
    turn_number int,
    is_last int2
);
COMMENT ON COLUMN "turn_qa"."turn_id" IS '会话轮次id';
COMMENT ON COLUMN "turn_qa"."chat_id" IS '会话id';
COMMENT ON COLUMN "turn_qa"."question" IS '该轮会话问题';
COMMENT ON COLUMN "turn_qa"."answer" IS '该轮会话答案';
COMMENT ON COLUMN "turn_qa"."similar_docs" IS '该轮会话相似文档 hash 索引';
COMMENT ON COLUMN "turn_qa"."create_time" IS '该轮会话创建时间,默认为当前时间';
COMMENT ON COLUMN "turn_qa"."turn_number" IS '会话轮数';
COMMENT ON COLUMN "turn_qa"."is_last" IS '是否为最后一轮对话:0=否,1=是';
COMMENT ON TABLE "turn_qa" IS '会话轮次信息表';
DROP SEQUENCE IF EXISTS "turn_qa_seq";
CREATE SEQUENCE "turn_qa_seq"
INCREMENT 1
MINVALUE 1
MAXVALUE 2147483647
START 1
CACHE 1;
ALTER TABLE turn_qa ALTER COLUMN turn_id SET DEFAULT nextval('turn_qa_seq'::regclass);
"""

# DDL script that drops and recreates the "c_user" table and its
# "c_user_seq" id sequence. Running it deletes any existing users.
#
# Columns (translated from the SQL COMMENT statements below):
#   user_id  - user id; defaults to nextval('c_user_seq')
#   account  - user account name (required, at most 20 characters)
#   password - user password (required, at most 50 characters)
#
# NOTE(review): "account" has no UNIQUE constraint, so duplicate accounts are
# possible at the database level. The password column stores whatever string
# callers insert; in this file it is inserted and compared as-is -- confirm
# whether callers hash it first.
TABLE_USER = """
DROP TABLE IF EXISTS "c_user";
CREATE TABLE c_user (
    user_id varchar(1000) PRIMARY KEY,
    account varchar(20) NOT NULL,
    password varchar(50) NOT NULL
);
COMMENT ON COLUMN "c_user"."user_id" IS '用户id';
COMMENT ON COLUMN "c_user"."account" IS '用户帐户';
COMMENT ON COLUMN "c_user"."password" IS '用户密码';
COMMENT ON TABLE "c_user" IS '用户表';
DROP SEQUENCE IF EXISTS "c_user_seq";
CREATE SEQUENCE "c_user_seq"
INCREMENT 1
MINVALUE 1
MAXVALUE 2147483647
START 1
CACHE 1;
ALTER TABLE c_user ALTER COLUMN user_id SET DEFAULT nextval('c_user_seq'::regclass);
"""


class CRUD:
    """Data-access helpers for the ``chat``, ``turn_qa`` and ``c_user`` tables.

    Each method issues one SQL statement through the wrapped database object,
    using its ``execute``/``execute_args`` and ``fetchone``/``fetchall``
    methods. Values are always passed as separate query arguments bound to
    ``%s`` placeholders, never formatted into the SQL text.

    NOTE(review): no commit/rollback happens here; presumably the database
    wrapper handles transactions -- confirm.
    """

    def __init__(self, _db: "UPostgresDB"):
        """Wrap a database handle.

        Args:
            _db: database wrapper exposing ``execute``, ``execute_args``,
                ``fetchone`` and ``fetchall``.
        """
        # The annotation is a string (forward reference) so importing this
        # class never requires evaluating the UPostgresDB name.
        self.db = _db

    def create_table(self):
        """(Re)create the chat, turn_qa and c_user tables and their sequences.

        WARNING: destructive -- each script begins with ``DROP TABLE IF
        EXISTS``, so all existing chats, turns and users are removed.
        """
        for ddl in (TABLE_CHAT, TABLE_TURN_QA, TABLE_USER):
            self.db.execute(ddl)

    def get_history(self, _chat_id):
        """Return every turn of a chat, ordered by ascending turn number.

        Returns:
            Rows of ``(turn_number, question, answer, is_last, similar_docs)``,
            with no filter on ``is_last``.
        """
        query = 'SELECT turn_number,question,answer,is_last,similar_docs FROM turn_qa WHERE chat_id=(%s) ORDER BY turn_number ASC'
        self.db.execute_args(query, (_chat_id,))
        return self.db.fetchall()

    def get_last_history(self, _chat_id):
        """Return the chat's turns flagged ``is_last = 1``, by turn number.

        Returns:
            Rows of ``(question, answer, similar_docs)``.
        """
        query = 'SELECT question,answer,similar_docs FROM turn_qa WHERE chat_id=(%s) and is_last=1  ORDER BY turn_number ASC'
        self.db.execute_args(query, (_chat_id,))
        return self.db.fetchall()

    def get_last_history_before_turn_id(self, _chat_id, turn_id):
        """Like ``get_last_history`` but only for turns before a given turn.

        Args:
            _chat_id: chat id.
            turn_id: despite its name, compared against ``turn_number``
                (``turn_number < turn_id``), not against the ``turn_id``
                column.

        Returns:
            Rows of ``(question, answer, similar_docs)``.
        """
        query = 'SELECT question,answer,similar_docs FROM turn_qa WHERE chat_id=(%s) and is_last=1 and turn_number<(%s) ORDER BY turn_number ASC'
        self.db.execute_args(query, (_chat_id, turn_id))
        return self.db.fetchall()

    def insert_turn_qa(self, chat_id, question, answer, turn_number, is_last, similar_docs=None):
        """Insert one question/answer turn; ``turn_id`` comes from its sequence."""
        query = 'INSERT INTO turn_qa(chat_id, question, answer, turn_number, is_last, similar_docs) VALUES (%s,%s,%s,%s,%s,%s)'
        self.db.execute_args(query, (chat_id, question, answer, turn_number, is_last, similar_docs))

    def insert_c_user(self, account, password):
        """Insert a user; ``user_id`` comes from its sequence.

        NOTE(review): ``password`` is stored exactly as given. Unless callers
        hash it beforehand, passwords are stored in plaintext.
        """
        query = 'INSERT INTO c_user(account, password) VALUES (%s,%s)'
        self.db.execute_args(query, (account, password))

    def insert_chat(self, user_id, info, deleted):
        """Insert a chat without returning its id (see ``create_chat``)."""
        query = 'INSERT INTO chat(user_id, info, deleted) VALUES (%s,%s,%s)'
        self.db.execute_args(query, (user_id, info, deleted))

    def update_last(self, chat_id):
        """Clear ``is_last`` on every row of the chat that currently has it set."""
        query = 'UPDATE turn_qa SET is_last = 0 WHERE chat_id = (%s) AND is_last = 1'
        self.db.execute_args(query, (chat_id,))

    def update_turn_last(self, chat_id, turn_number):
        """Clear ``is_last`` on all rows of the chat with the given turn number."""
        query = 'UPDATE turn_qa SET is_last = 0 WHERE chat_id = (%s) AND turn_number = (%s)'
        self.db.execute_args(query, (chat_id, turn_number))

    def user_exist_id(self, _user_id):
        """Return the full ``c_user`` row for ``_user_id``, or None if absent."""
        query = 'SELECT * FROM c_user WHERE user_id = (%s)'
        self.db.execute_args(query, (_user_id,))
        return self.db.fetchone()

    def user_exist_account(self, _account):
        """Return the first ``c_user`` row for ``_account``, or None if absent."""
        query = 'SELECT * FROM c_user WHERE account = (%s)'
        self.db.execute_args(query, (_account,))
        return self.db.fetchone()

    def user_exist_account_password(self, _account, _password):
        """Return ``(user_id,)`` when account and password match, else None.

        NOTE(review): the password is compared verbatim in SQL, which implies
        plaintext storage unless callers hash it first.
        """
        query = 'SELECT user_id FROM c_user WHERE account = (%s) AND password = (%s)'
        self.db.execute_args(query, (_account, _password))
        return self.db.fetchone()

    def chat_exist_chatid_userid(self, _chat_id, _user_id):
        """Return the chat row if ``_chat_id`` belongs to ``_user_id``, else None."""
        query = 'SELECT * FROM chat WHERE chat_id = (%s) AND user_id = (%s)'
        self.db.execute_args(query, (_chat_id, _user_id))
        return self.db.fetchone()

    def get_chat_list_userid(self, _user_id):
        """Return ``(chat_id,)`` rows of the user's non-deleted chats, newest first."""
        query = 'SELECT chat_id FROM chat WHERE user_id = (%s) AND deleted = 0 order by create_time desc'
        self.db.execute_args(query, (_user_id,))
        return self.db.fetchall()

    def get_chatinfo_from_chatid(self, _chat_id):
        """Return ``(info,)`` for the chat, or None if it does not exist."""
        query = 'SELECT info FROM chat WHERE chat_id = (%s)'
        self.db.execute_args(query, (_chat_id,))
        return self.db.fetchone()

    def delete_chat(self, _chat_id):
        """Soft-delete a chat by setting ``deleted = 1``; no rows are removed."""
        query = 'UPDATE chat SET deleted = 1 WHERE chat_id = (%s)'
        self.db.execute_args(query, (_chat_id,))

    def get_last_question(self, _chat_id):
        """Return the question of the chat's turn number 1.

        Despite the name, this selects ``turn_number = 1`` (the first turn).
        If several rows share turn 1, the first row returned by the database
        is used.

        Returns:
            The question text, or None when the chat has no turn 1.
        """
        query = 'SELECT question FROM turn_qa WHERE chat_id = (%s) AND turn_number = 1'
        self.db.execute_args(query, (_chat_id,))
        row = self.db.fetchone()
        # A chat without a first turn yields no row; avoid failing on None[0].
        return row[0] if row is not None else None

    def get_users(self):
        """Return ``(account,)`` rows for every user."""
        query = 'SELECT account FROM c_user'
        self.db.execute(query)
        return self.db.fetchall()

    def get_chats(self, account):
        """Return ``(chat_id, info)`` rows for an account's chats, newest first.

        Unlike ``get_chat_list_userid``, soft-deleted chats are included.
        """
        query = 'SELECT chat.chat_id,chat.info FROM chat JOIN c_user ON chat.user_id = c_user.user_id WHERE c_user.account = (%s) ORDER BY chat.create_time DESC;'
        self.db.execute_args(query, (account,))
        return self.db.fetchall()

    def create_chat(self, user_id, info, deleted):
        """Insert a chat and return its generated ``chat_id``."""
        query = 'INSERT INTO chat(user_id, info, deleted) VALUES (%s,%s,%s) RETURNING chat_id'
        self.db.execute_args(query, (user_id, info, deleted))
        return self.db.fetchall()[0][0]

    def get_uersid_from_account(self, account):
        """Return the ``user_id`` for ``account``.

        Raises:
            IndexError: if no user has this account.
        """
        query = 'SELECT user_id FROM c_user WHERE account = (%s)'
        self.db.execute_args(query, (account,))
        return self.db.fetchall()[0][0]

    # Correctly spelled alias; the misspelled name is kept for existing callers.
    get_userid_from_account = get_uersid_from_account

    def get_chat_info(self, chat_id):
        """Return the ``info`` text of a chat.

        Raises:
            IndexError: if the chat does not exist.
        """
        query = 'SELECT info FROM chat WHERE chat_id = (%s)'
        self.db.execute_args(query, (chat_id,))
        return self.db.fetchall()[0][0]

    def set_info(self, chat_id, info):
        """Overwrite the ``info`` text of a chat."""
        query = 'UPDATE chat SET info = (%s) WHERE chat_id = (%s)'
        self.db.execute_args(query, (info, chat_id))

    def get_last_turn_num(self, chat_id):
        """Return the highest ``turn_number`` of the chat, or None if it has no turns."""
        query = 'SELECT max(turn_number) FROM turn_qa WHERE chat_id = (%s)'
        self.db.execute_args(query, (chat_id,))
        return self.db.fetchone()[0]