package job

import (
	"bytes"
	"chainweaver.org.cn/chainweaver/mira/mira-backend-service/config"
	"chainweaver.org.cn/chainweaver/mira/mira-backend-service/core/ctrl/common"
	"chainweaver.org.cn/chainweaver/mira/mira-backend-service/core/ctrl/handler"
	"chainweaver.org.cn/chainweaver/mira/mira-backend-service/core/service"
	"chainweaver.org.cn/chainweaver/mira/mira-common/logger"
	"chainweaver.org.cn/chainweaver/mira/mira-common/types"
	"chainweaver.org.cn/chainweaver/mira/mira-common/types/state"
	"encoding/json"
	"github.com/gin-gonic/gin"
	"github.com/pkg/errors"
	"github.com/spf13/cast"
	"net/http"
	"strconv"
)

// GetJobDagByPqlHandler handles a "build job DAG from PQL" request.
// BindReq fills req from the JSON request body, Process records either an
// error or the resulting job on resp, and GetResponse exposes resp.
type GetJobDagByPqlHandler struct {
	// req is the decoded request payload.
	req  GetJobDagByPqlReq
	// resp is the response populated by BindReq (on bind errors) and Process.
	resp common.Response
}

// GetJobDagByPqlReq is the JSON request body for GetJobDagByPqlHandler.
// It embeds common.BaseRequest; Process reads RequestId and ChainInfoId
// from it.
type GetJobDagByPqlReq struct {
	common.BaseRequest

	PqlText     string            `json:"pqlText"`     // PQL text to build the DAG from
	ProcessType state.ProcessType `json:"processType"` // task (process) type
}

// GetJobDagByPqlResp is the data payload Process stores on the response on
// success.
type GetJobDagByPqlResp struct {
	// Job is the job returned by the commit-DAG call, enriched by Process.
	Job *types.Job `json:"job"`
}

// OrgInfo identifies an organisation inside a SqlVo payload. Process fills
// it with the local platform's ID (as a string) and name.
type OrgInfo struct {
	OrgId   string `json:"orgId"`
	OrgName string `json:"orgName"`
}

// SqlVo is the JSON body Process posts to the job service's
// /v1/preview/dag and /v1/commit/dag endpoints.
type SqlVo struct {
	// SqlText carries the PQL text; note the all-lowercase JSON key.
	SqlText       string             `json:"sqltext"`
	// ModelType is always sent as 0 by Process.
	ModelType     int                `json:"modelType"`
	// IsStream is 1 or 0; Process derives it from the request's ProcessType.
	IsStream      int                `json:"isStream"`
	// AssetInfoList is nil for the preview call and holds the resolved
	// assets for the commit call.
	AssetInfoList []*types.AssetInfo `json:"assetInfoList"`
	// ModelParams is left nil by Process.
	ModelParams   []*ModelParamsVo   `json:"modelParams"`
	// OrgInfo identifies the requesting platform.
	OrgInfo       OrgInfo            `json:"orgInfo"`
}

// ModelParamsVo is the element type of SqlVo.ModelParams. Nothing in this
// file populates it.
// NOTE(review): field meanings are presumably defined by the job service
// API — confirm against its documentation.
type ModelParamsVo struct {
	Name         string `json:"name"`
	Type         string `json:"type"`
	Version      string `json:"version"`
	CreateOrgDID string `json:"createOrgDID"`
	MethodName   string `json:"methodName"`
}

// BindReq decodes the JSON request body into h.req.
//
// When decoding fails the error is logged, recorded on the response with
// common.ErrCodeInvalidParameter, and returned to the caller; otherwise nil
// is returned.
func (h *GetJobDagByPqlHandler) BindReq(c *gin.Context) error {
	bindErr := c.ShouldBindJSON(&h.req)
	if bindErr == nil {
		return nil
	}

	logger.Errorf("BindReq error: %v", bindErr)
	h.resp.SetError(common.ErrCodeInvalidParameter, bindErr.Error())
	return bindErr
}

// Process builds a job DAG from the request's PQL text and stores it on the
// response.
//
// Steps, in order:
//  1. Fetch the local platform info (ID and name).
//  2. POST a SqlVo to <JobServiceUrl>/v1/preview/dag to learn which tables
//     the PQL references.
//  3. Resolve every returned table name to an asset via GetAssetByEnName.
//  4. POST the SqlVo, now carrying the assets, to <JobServiceUrl>/v1/commit/dag
//     to obtain the job.
//  5. Enrich the job: PQL text, process type, a single result receiver (the
//     local platform), and a de-duplicated party list with each party's key.
//
// On any failure the error is logged, recorded on h.resp with SetError and
// processing stops. On success the job is wrapped in GetJobDagByPqlResp and
// stored with SetData.
func (h *GetJobDagByPqlHandler) Process() {
	// Anything below ProcessType_Streaming is sent as a non-stream job.
	isStream := 1
	if h.req.ProcessType < state.ProcessType_Streaming {
		isStream = 0
	}

	// Fetch the platform info.
	platformInfo, err := service.MiraIdaAccessServiceImpl.GetPlatformInfo(uint32(h.req.ChainInfoId), h.req.RequestId)
	if err != nil {
		logger.Errorf("get platform info failed, err: %v", err)
		h.resp.SetError(common.ErrCodeGetPlatformFail, err.Error())
		return
	}

	// Platform ID and platform name.
	partyId := cast.ToString(platformInfo.PlatformId)
	partyName := platformInfo.PlatformName

	sqlVo := &SqlVo{
		SqlText:       h.req.PqlText,
		ModelType:     0,
		IsStream:      isStream,
		AssetInfoList: nil,
		ModelParams:   nil,
		OrgInfo: OrgInfo{
			OrgId:   partyId,
			OrgName: partyName,
		},
	}

	// Preview the DAG to discover the tables the PQL touches.
	paramBytes, err := json.Marshal(sqlVo)
	if err != nil {
		logger.Errorf("json marshal failed, err: %v", err)
		h.resp.SetError(common.ErrCodeGetDAGFail, err.Error())
		return
	}

	type TableModule struct {
		Tables map[string]string `json:"tables"`
		Models []string          `json:"models"`
	}

	type PreviewDAGResult struct {
		Code int          `json:"code"`
		Data *TableModule `json:"data"`
	}

	previewDAGResult := &PreviewDAGResult{}
	_, err = ReadByteFromUrl(config.Conf.JobServiceUrl+"/v1/preview/dag", paramBytes, previewDAGResult)
	if err != nil {
		logger.Errorf("preview dag failed, err: %v", err)
		h.resp.SetError(common.ErrCodePreviewDAGFail, err.Error())
		return
	}

	if previewDAGResult.Code != 200 {
		logger.Errorf("preview dag failed, err: %v", errors.New("code is not 200"))
		h.resp.SetError(common.ErrCodePreviewDAGFail, "code is not 200")
		return
	}

	// A code-200 reply without a "data" object would otherwise be
	// dereferenced below and panic.
	if previewDAGResult.Data == nil {
		logger.Errorf("preview dag failed, err: response contains no data")
		h.resp.SetError(common.ErrCodePreviewDAGFail, "preview dag response contains no data")
		return
	}

	// Resolve each table's value (the asset name) to its asset info. The
	// slice is created non-nil so it still marshals as [] when empty.
	tables := previewDAGResult.Data.Tables
	assetInfoList := make([]*types.AssetInfo, 0, len(tables))
	for _, assetName := range tables {
		assetInfo, err := service.MiraIdaAccessServiceImpl.GetAssetByEnName(h.req.RequestId, h.req.ChainInfoId, assetName)
		if err != nil {
			logger.Errorf("get asset info failed, err: %v", err)
			h.resp.SetError(common.ErrCodeGetAssetByEnNameFailed, err.Error())
			return
		}
		assetInfoList = append(assetInfoList, assetInfo)
	}

	sqlVo.AssetInfoList = assetInfoList

	// Commit the DAG.
	commitParamsBytes, err := json.Marshal(sqlVo)
	if err != nil {
		logger.Errorf("json marshal failed, err: %v", err)
		h.resp.SetError(common.ErrCodeCommitDAGFail, err.Error())
		return
	}

	type JobCommit struct {
		Job *types.Job `json:"job"`
	}

	type CommitDAGResult struct {
		Code    int        `json:"code"`
		Message string     `json:"message"`
		Data    *JobCommit `json:"data"`
	}

	commitDAGResult := &CommitDAGResult{}
	_, err = ReadByteFromUrl(config.Conf.JobServiceUrl+"/v1/commit/dag", commitParamsBytes, commitDAGResult)
	if err != nil {
		logger.Errorf("commit dag failed, err: %v", err)
		h.resp.SetError(common.ErrCodeCommitDAGFail, err.Error())
		return
	}

	if commitDAGResult.Code != 200 {
		logger.Errorf("commit dag failed, err: %v", errors.New("code is not 200"))
		h.resp.SetError(common.ErrCodeCommitDAGFail, commitDAGResult.Message)
		return
	}

	// Guard against a code-200 reply that carries no job; the job is
	// dereferenced and mutated below.
	if commitDAGResult.Data == nil || commitDAGResult.Data.Job == nil {
		logger.Errorf("commit dag failed, err: response contains no job")
		h.resp.SetError(common.ErrCodeCommitDAGFail, "commit dag response contains no job")
		return
	}

	logger.Info(" =========== jobCommitVO deserialzie ")

	job := commitDAGResult.Data.Job
	job.PqlText = h.req.PqlText
	job.ProcessType = h.req.ProcessType

	// The local platform is the only result receiver.
	job.ResultReceiverList = []*types.ResultReceiver{
		{
			Id:          partyId,
			PartyId:     partyId,
			PartyName:   partyName,
			IsEncrypted: false,
			PubKeyName:  "",
			PubKey:      "",
		},
	}

	// Every service's party takes part in the job.
	for _, jobService := range job.ServiceList {
		job.PartyList = append(job.PartyList, &types.Party{
			PartyId:   jobService.PartyId,
			PartyName: jobService.PartyName,
		})
	}

	// De-duplicate parties by PartyId, filtering in place and keeping the
	// first occurrence of each ID.
	seen := make(map[string]struct{}, len(job.PartyList))
	kept := 0
	for _, party := range job.PartyList {
		if _, dup := seen[party.PartyId]; dup {
			continue
		}
		seen[party.PartyId] = struct{}{}
		job.PartyList[kept] = party
		kept++
	}
	job.PartyList = job.PartyList[:kept]

	// Attach each party's platform key.
	for _, party := range job.PartyList {
		partyNum, err := strconv.Atoi(party.PartyId)
		if err != nil {
			logger.Errorf("GetPlatformPKFail failed, partyId： %s, err: %v", party.PartyId, err)
			h.resp.SetError(common.ErrGetPlatformPKFail, err.Error())
			return
		}
		pk, err := service.MiraIdaAccessServiceImpl.GetPrivatePlatformPK(h.req.RequestId, h.req.ChainInfoId, partyNum)
		if err != nil {
			logger.Errorf("GetPlatformPKFail failed, partyId： %d, err: %v", partyNum, err)
			h.resp.SetError(common.ErrGetPlatformPKFail, err.Error())
			return
		}
		party.PubKey = pk
	}

	h.resp.SetData(&GetJobDagByPqlResp{
		Job: job,
	})
}

// GetResponse returns a pointer to the handler's own response value, so the
// caller sees whatever BindReq or Process recorded on it.
func (h *GetJobDagByPqlHandler) GetResponse() *common.Response {
	resp := &h.resp
	return resp
}

// GetJobDagByPqlHandlerFunc is the gin entry point: it creates a fresh
// GetJobDagByPqlHandler for the request and hands it to handler.Run.
func GetJobDagByPqlHandlerFunc(c *gin.Context) {
	h := &GetJobDagByPqlHandler{}
	handler.Run(h, c)
}

// ReadByteFromUrl POSTs params as an application/json body to url, reads the
// whole response body and JSON-decodes it into data.
//
// data should be a non-nil pointer to the value to fill, exactly as for
// json.Unmarshal. A non-pointer or nil-pointer target is reported as an
// error instead of being silently left untouched. A nil data is still
// accepted: the body is then only checked for being valid JSON.
//
// It returns the raw response bytes on success. On any failure (request,
// read or decode) the returned slice is nil and the underlying error is
// returned unchanged. The HTTP status code is not inspected here; callers
// rely on the "code" field of the decoded payload.
//
// NOTE(review): http.Post goes through http.DefaultClient, which has no
// timeout — confirm whether callers need a bounded wait.
func ReadByteFromUrl(url string, params []byte, data interface{}) ([]byte, error) {
	resp, err := http.Post(url, "application/json", bytes.NewReader(params))
	if err != nil {
		return nil, err
	}
	defer resp.Body.Close()

	var buf bytes.Buffer
	if _, err = buf.ReadFrom(resp.Body); err != nil {
		return nil, err
	}
	body := buf.Bytes()

	// Decode straight into the caller's target. Passing &data (a pointer to
	// the local interface variable) would let json.Unmarshal overwrite the
	// local copy when data is not a usable pointer, hiding the mistake.
	target := data
	if target == nil {
		var discard interface{}
		target = &discard
	}
	if err = json.Unmarshal(body, target); err != nil {
		return nil, err
	}
	return body, nil
}

// NewGetJobDagByPqlHandler builds a handler whose request is already
// populated from the given base request, PQL text and process type, for
// callers that do not go through BindReq. The response starts at its zero
// value.
func NewGetJobDagByPqlHandler(baseRequest common.BaseRequest, pqlText string, processType state.ProcessType) *GetJobDagByPqlHandler {
	h := new(GetJobDagByPqlHandler)
	h.req.BaseRequest = baseRequest
	h.req.PqlText = pqlText
	h.req.ProcessType = processType
	return h
}
