import _thread as thread
import base64
import datetime
import hashlib
import hmac
import json
import ssl
from datetime import datetime
from time import mktime
from urllib.parse import urlparse, urlencode
from wsgiref.handlers import format_date_time
import websocket  # 使用websocket_client

# Chat WebSocket endpoint and the matching request "domain" value for each
# model version. SparkAPI(Version="v1") selects the V1_5 pair and
# SparkAPI(Version="v2") selects the V2 pair.
URL_V1_5 = "ws://spark-api.xf-yun.com/v1.1/chat"
Domain_V1_5 = "general"

URL_V2 = "ws://spark-api.xf-yun.com/v2.1/chat"
Domain_V2 = "generalv2"

class SparkAPI:
    """Minimal WebSocket client for the iFlytek Spark chat API.

    Each request is authenticated by an HMAC-SHA256 signature carried in the
    connection URL (see ``create_url``). The request body is sent once the
    socket opens, and the streamed response fragments are concatenated into
    ``self.answer``.

    Usage::

        api = SparkAPI(appid, api_key, api_secret, Version="v2")
        text = api.call([{"role": "user", "content": "Hello"}])
    """

    def __init__(self, APPID, APIKey, APISecret, Version="v1"):
        """Store credentials and select the endpoint/domain for ``Version``.

        Args:
            APPID: application id, sent as ``header.app_id`` in the request.
            APIKey: API key embedded in the authorization string.
            APISecret: secret used as the HMAC-SHA256 signing key.
            Version: ``"v1"`` (uses ``URL_V1_5`` / ``Domain_V1_5``) or
                ``"v2"`` (uses ``URL_V2`` / ``Domain_V2``).

        Raises:
            ValueError: if ``Version`` is neither ``"v1"`` nor ``"v2"``.
        """
        self.APPID = APPID
        self.APIKey = APIKey
        self.APISecret = APISecret

        if Version == "v1":
            self.Spark_url = URL_V1_5
            self.domain = Domain_V1_5
        elif Version == "v2":
            self.Spark_url = URL_V2
            self.domain = Domain_V2
        else:
            # Without this branch an unknown version left Spark_url unset and
            # failed just below with an unhelpful AttributeError.
            raise ValueError(
                f"Unsupported Version {Version!r}; expected 'v1' or 'v2'"
            )

        parsed = urlparse(self.Spark_url)
        self.host = parsed.netloc
        self.path = parsed.path
        # Accumulated response text of the most recent call().
        self.answer = ""
        # Last error seen during call(). It holds the decoded response dict
        # for API error frames, or the object passed to on_error for
        # transport errors. None means no error was observed.
        self.last_error = None

    def create_url(self):
        """Build the authenticated WebSocket URL for the selected endpoint.

        The signature is HMAC-SHA256, keyed with ``APISecret``, over the
        newline-joined lines ``host: <host>``, ``date: <date>`` and
        ``GET <path> HTTP/1.1``. It is placed in an authorization string,
        base64-encoded, and sent together with ``date`` and ``host`` as
        query parameters.

        Returns:
            str: ``Spark_url`` followed by ``?`` and the url-encoded query.
        """
        # RFC 1123 date in GMT, e.g. "Mon, 01 Jan 2024 00:00:00 GMT".
        now = datetime.now()
        date = format_date_time(mktime(now.timetuple()))

        signature_origin = "host: " + self.host + "\n"
        signature_origin += "date: " + date + "\n"
        signature_origin += "GET " + self.path + " HTTP/1.1"

        signature_sha = hmac.new(self.APISecret.encode('utf-8'), signature_origin.encode('utf-8'),
                                 digestmod=hashlib.sha256).digest()

        signature_sha_base64 = base64.b64encode(signature_sha).decode(encoding='utf-8')

        authorization_origin = f'api_key="{self.APIKey}", algorithm="hmac-sha256", headers="host date request-line", signature="{signature_sha_base64}"'

        authorization = base64.b64encode(authorization_origin.encode('utf-8')).decode(encoding='utf-8')

        v = {
            "authorization": authorization,
            "date": date,
            "host": self.host
        }

        url = self.Spark_url + '?' + urlencode(v)
        return url

    def on_error(self, ws, error):
        """websocket-client error callback: record and print ``error``."""
        self.last_error = error
        print("### error:", error)

    def on_close(self, ws, one=None, two=None):
        """websocket-client close callback.

        websocket-client >= 1.0 passes two extra positional arguments (close
        status code and close message). Some older releases may pass none.
        Both extra arguments are therefore optional, and they are ignored.
        """
        print(" ")

    def on_open(self, ws):
        """websocket-client open callback: send the request on a new thread.

        NOTE(review): the separate thread presumably keeps the callback from
        blocking websocket-client's dispatch loop. Confirm it is still needed.
        """
        thread.start_new_thread(self.run, (ws,))

    def run(self, ws, *args):
        """Serialize the request for ``ws.question`` and send it over ``ws``."""
        data = json.dumps(self.gen_params(appid=self.APPID, domain=self.domain, question=ws.question))
        ws.send(data)

    def on_message(self, ws, message):
        """Handle one streamed response frame.

        A non-zero ``header.code`` is an API error. The frame is printed and
        stored in ``last_error``, and the socket is closed. Otherwise the
        fragment ``payload.choices.text[0].content`` is appended to
        ``self.answer``. The socket is closed when ``payload.choices.status``
        equals 2, which is treated as the last frame.
        """
        data = json.loads(message)
        code = data['header']['code']
        if code != 0:
            print(f'请求错误: {code}, {data}')
            self.last_error = data
            ws.close()
            return

        choices = data["payload"]["choices"]
        self.answer += choices["text"][0]["content"]
        if choices["status"] == 2:
            ws.close()

    def gen_params(self, appid, domain, question):
        """Build the JSON-serializable request body.

        Args:
            appid: value for ``header.app_id``.
            domain: model domain for ``parameter.chat.domain``.
            question: conversation sent as ``payload.message.text``. This is
                normally a list of ``{"role": ..., "content": ...}`` dicts.
                A plain ``str`` is wrapped as a single user message.

        Returns:
            dict: the request body.
        """
        if isinstance(question, str):
            # The chat API expects a message list, not a bare string.
            question = [{"role": "user", "content": question}]
        data = {
            "header": {
                "app_id": appid,
                "uid": "1234"
            },
            "parameter": {
                "chat": {
                    "domain": domain,
                    "random_threshold": 0.5,
                    "max_tokens": 2048,
                    "auditing": "default"
                }
            },
            "payload": {
                "message": {
                    "text": question
                }
            }
        }
        return data

    def call(self, question):
        """Send ``question`` and block until the connection closes.

        Args:
            question: see ``gen_params``.

        Returns:
            str: the concatenated answer, which is also left in
            ``self.answer``. It may be empty or partial if an error occurred.
            In that case ``self.last_error`` is set.
        """
        self.answer = ""
        self.last_error = None
        wsUrl = self.create_url()
        websocket.enableTrace(False)
        ws = websocket.WebSocketApp(wsUrl, on_message=self.on_message, on_error=self.on_error, on_close=self.on_close, on_open=self.on_open)
        # The request is read back from here in run().
        ws.question = question
        # NOTE(review): CERT_NONE disables TLS certificate verification when
        # the URL is wss://. Consider dropping sslopt once the CA store is
        # known to be available.
        ws.run_forever(sslopt={"cert_reqs": ssl.CERT_NONE})
        return self.answer