// Copyright (C) BABEC. All rights reserved.
// Copyright (C) THL A29 Limited, a Tencent company. All rights reserved.
//
// SPDX-License-Identifier: Apache-2.0

package common

import (
	"gorm.io/gorm"
)

// Package-level state shared by the helpers in this file.
var (
	// mysqlDB is initialized to an empty gorm.DB. It is not referenced in the
	// visible part of this file.
	// NOTE(review): possibly used elsewhere in the package — confirm before removing.
	mysqlDB = &gorm.DB{}
	// dbc is the shared DbStruct that InitDbStruct fills in and GetDbStruct returns.
	dbc     = &DbStruct{}
)

//var db *gorm.DB

// DbStruct wraps a *gorm.DB handle and exposes the query, create, update and
// delete helpers defined as methods on it.
type DbStruct struct {
	// db is the underlying gorm handle that every method delegates to.
	db *gorm.DB
}

// InitDbStruct stores db as the gorm handle used by the package-level
// DbStruct, replacing whatever handle was set before.
func InitDbStruct(db *gorm.DB) {
	shared := dbc
	shared.db = db
}

// GetDbStruct returns the package-level DbStruct, i.e. the same instance that
// InitDbStruct writes the gorm handle into.
func GetDbStruct() *DbStruct {
	shared := dbc
	return shared
}

// Query loads the rows selected by the optional condition into dst and returns
// the *gorm.DB produced by the Find call, so callers can inspect its Error and
// RowsAffected.
//
// Parameters:
//   - dst: destination handed to gorm's Find.
//   - orderer: ORDER BY expression; an empty string means no explicit ordering.
//   - where, args: condition and its bind values, passed to gorm's Where. A nil
//     where selects all rows.
func (d *DbStruct) Query(dst interface{}, orderer string, where interface{}, args ...interface{}) *gorm.DB {
	tx := d.db
	if orderer != "" {
		// The ordering must be attached to the chain before Find executes;
		// applying it afterwards (or to a separate handle) has no effect on
		// the rows that were already fetched.
		tx = tx.Order(orderer)
	}
	if where != nil {
		tx = tx.Where(where, args...)
	}
	// Return the handle of the Find itself so the query's error is not lost.
	return tx.Find(dst)
}

// QueryLimit loads one page of rows into dst and stores a row count in count.
//
// Parameters:
//   - dst: destination handed to gorm's Find.
//   - orderer: ORDER BY expression; an empty string means no explicit ordering.
//   - offset, limit: pagination window applied to the Find.
//   - where, args: condition and its bind values, passed to gorm's Where. A nil
//     where selects all rows.
//   - count: receives the result of the Count call that follows the Find.
//
// The returned *gorm.DB is the one produced by the final Count call.
func (d *DbStruct) QueryLimit(dst interface{}, orderer string, offset,
	limit int, where interface{}, count *int64, args ...interface{}) *gorm.DB {
	// Every path runs in a Debug session. Chaining from tx (rather than from
	// d.db again) keeps that session when an ordering is requested.
	tx := d.db.Debug()
	if orderer != "" {
		tx = tx.Order(orderer)
	}
	if where != nil {
		tx = tx.Where(where, args...)
	}
	// Fetch the page first, then reset limit and offset to -1 before counting
	// so the Count is not restricted to the page window.
	return tx.Offset(offset).Limit(limit).
		Find(dst).Limit(-1).Offset(-1).Count(count)
}

// QueryCount returns the number of rows in tableName that satisfy the optional
// condition.
//
// Parameters:
//   - tableName: table the count runs against.
//   - where, args: condition and its bind values, passed to gorm's Where. A nil
//     where adds no condition.
//
// The error of the underlying query is not surfaced by this signature; callers
// that need it should build the query through GetDb instead.
func (d *DbStruct) QueryCount(tableName string, where interface{}, args ...interface{}) int64 {
	var instanceNum int64
	tx := d.db.Table(tableName)
	if where != nil {
		// Spread args so each value binds to its own placeholder instead of
		// the whole slice being passed as a single argument.
		tx = tx.Where(where, args...)
	}
	tx.Count(&instanceNum)
	return instanceNum
}

// GetDb exposes the underlying *gorm.DB handle held by d, for callers that
// need to build a query the helper methods do not cover.
func (d *DbStruct) GetDb() *gorm.DB {
	handle := d.db
	return handle
}

// QueryFirst loads a single row into dst via gorm's First and returns the
// resulting *gorm.DB.
//
// When where is non-nil it is applied, together with args, as the condition;
// a nil where runs First with no condition.
func (d *DbStruct) QueryFirst(dst interface{}, where interface{}, args ...interface{}) *gorm.DB {
	if where != nil {
		// Conditional lookup.
		return d.db.Where(where, args...).First(dst)
	}
	// No condition supplied: search across all rows.
	return d.db.First(dst)
}

// Create passes m to gorm's Create and returns the resulting *gorm.DB so the
// caller can check its Error.
func (d *DbStruct) Create(m interface{}) *gorm.DB {
	result := d.db.Create(m)
	return result
}

// Save passes m to gorm's Save, which updates the record if it already
// exists, and returns the resulting *gorm.DB.
func (d *DbStruct) Save(m interface{}) *gorm.DB {
	result := d.db.Save(m)
	return result
}

// Update writes the fields of m through gorm's Updates and returns the
// resulting *gorm.DB.
//
// Parameters:
//   - m: value carrying both the model and the new field values.
//   - where, args: condition and its bind values, passed to gorm's Where.
//
// When where is nil no explicit condition is added and m is handed to Updates
// directly.
// NOTE(review): which rows are affected in that case is decided by gorm from m
// (presumably its primary key) — confirm against callers.
func (d *DbStruct) Update(m interface{}, where interface{}, args ...interface{}) *gorm.DB {
	if where == nil {
		return d.db.Updates(m)
	}
	// Spread args so each value binds to its own placeholder instead of the
	// whole slice being passed as a single argument.
	return d.db.Model(m).Where(where, args...).Updates(m)
}

// Delete removes rows through gorm's Delete and returns the resulting
// *gorm.DB.
//
// Parameters:
//   - m: value passed to gorm's Delete.
//   - where, args: condition and its bind values, passed to gorm's Where.
//
// When where is nil no condition is added and m is passed to Delete as-is.
// NOTE(review): gorm is documented to refuse condition-less deletes unless m
// carries a primary key — confirm this is the behaviour callers expect.
func (d *DbStruct) Delete(m interface{}, where interface{}, args ...interface{}) *gorm.DB {
	tx := d.db
	if where != nil {
		// Spread args so each value binds to its own placeholder instead of
		// the whole slice being passed as a single argument.
		tx = tx.Where(where, args...)
	}
	return tx.Delete(m)
}
