#include "CryptoBacktest.h"

#include "CoreConcept.h"
#include "Exceptions.h"
#include "ScalarImp.h"
#include "SmartPointer.h"
#include "Swordfish.h"
#include "Types.h"

using namespace ddb;

CryptoBacktest::CryptoBacktest(ddb::SmartPointer<DDBConnection> conn) : conn_(conn) {}

SmartPointer<BacktestWrapper> CryptoBacktest::runCryptoBacktest(ConstantSP strategyName, DictionarySP userConfig,
                                                                DictionarySP eventCallbacks) {
    ConstantSP strategyName_ = strategyName;
    ConstantSP startDate = userConfig->getMember("startDate");
    ConstantSP endDate = userConfig->getMember("endDate");
    VectorSP universes = userConfig->getMember("Universe");
    DictionarySP userConfig_ = userConfig->getValue();

    ConstantSP target = new String("futuresOrSpot");
    ConstantSP resultSP = new Bool(false);
    userConfig_->contain(target, resultSP);
    if (!resultSP->getBool()) {
        userConfig_->set(target, new Long(2));
    }
    target = new String("Universe");
    userConfig_->contain(target, resultSP);
    if (!resultSP->getBool() or universes->size() == 0) {
        throw RuntimeException("没有设置回测的标的池，请先设置标的池！");
    }
    ConstantSP futuresOrSpot = userConfig_->getMember("futuresOrSpot");
    if (futuresOrSpot->getInt() == 0 or futuresOrSpot->getInt() == 2) {
        TableSP fundingRate = getCryptoFundingRate(universes, startDate, endDate);
        if (fundingRate->size() == 0) {
            throw RuntimeException("表（dfs://Binance_KLine_Day，fundingRate）中没有永续合约的费率数据！");
        }
        userConfig_->set("fundingRate", fundingRate);
    }
    long dataType = userConfig_->getMember("dataType")->getLong();
    if (dataType == 3) {
        TableSP data = getBinanceMinKLine(universes, startDate, endDate, futuresOrSpot);
        if (data->size() < 1) {
            throw RuntimeException(
                "表（dfs://Binance_KLine_Day，KLine_Min_Future或者KLine_Min_Spot）中没有对应的数字货币分钟数据！");
        }
        TableSP securityReference = getCryptoSecurityReference(data);
        auto engine = BacktestWrapper::createBacktester(strategyName_, userConfig_, eventCallbacks, new Bool(false),
                                                        securityReference);
        cout << "开始执行回测！" << endl;
        engine->appendQuotationMsg(data);
        auto session = engine->getSession();
        session->getHeap()->addItem("data", data);
        TableSP endFlag = DolphinDBLib::execute(
            session,
            "endFlag = select top 1 * from data order by tradeTime desc; update endFlag set symbol = 'END'; endFlag");
        engine->appendQuotationMsg(endFlag);
        cout << "回测执行结束！可以通过回测引擎获取回测结果！" << endl;
        return engine;
    } else if (dataType == 4) {
        throw RuntimeException(
            "表（dfs://Binance_KLine_Day，KLine_Min_Future或者KLine_Min_Spot）中没有对应的的数字货币日线数据！");
    } else if (dataType == 1) {
        TableSP securityReference = getCryptoSecurityReferenceFromDepth(universes, startDate, endDate);
        if (securityReference->size() < 1) {
            throw RuntimeException("表（dfs://Binance_Tick，depth）中没有对应的的数字货币快照行情数据！");
        }
        auto engine = BacktestWrapper::createBacktester(strategyName, userConfig_, eventCallbacks, new Bool(false),
                                                        securityReference);
        cout << "开始获取快照行情数据数据！并执行回测" << endl;
        VectorSP dates = DolphinDBLib::execute(startDate->getString() + ".." + endDate->getString());
        TableSP data = nullptr;
        for (INDEX i = 0; i < dates->size(); ++i) {
            ConstantSP idate = dates->get(i);
            data = getSnapshotData(universes, idate, idate);
            engine->appendQuotationMsg(data);
        }
        auto session = engine->getSession();
        session->getHeap()->addItem("data", data);
        TableSP endFlag = DolphinDBLib::execute(
            session,
            "endFlag = select top 1 * from data order by timestamp desc; update endFlag set symbol = 'END'; endFlag");
        engine->appendQuotationMsg(endFlag);
        cout << "回测执行结束！可以通过回测引擎获取回测结果！" << endl;
        return engine;
    }

    return nullptr;
}

/**
 * Builds a default configuration dictionary for the "cryptocurrency" strategy group.
 *
 * Besides the given startDate / endDate / dataType, the result contains:
 *   - "cash": a dictionary with 1,000,000 for each of "spot", "futures" and "option";
 *   - "Universe": a one-element string vector ["BTCUSDT"];
 *   - "context": a dictionary with its own "Universe" (["BTCUSDT_futures"]) and an empty
 *     "log" table (tradeDate DATE, time TIMESTAMP, info STRING).
 */
ddb::DictionarySP CryptoBacktest::getUnifiedConfig(ddb::ConstantSP startDate, ddb::ConstantSP endDate,
                                                   ddb::ConstantSP dataType) {
    DictionarySP config = Util::createDictionary(DT_STRING, nullptr, DT_ANY, nullptr);
    config->set("startDate", startDate);
    config->set("endDate", endDate);
    config->set("strategyGroup", new String("cryptocurrency"));

    // Same starting cash for every account type, inserted in spot/futures/option order.
    DictionarySP cashByAccount = Util::createDictionary(DT_STRING, nullptr, DT_DOUBLE, nullptr);
    for (const char* account : {"spot", "futures", "option"}) {
        cashByAccount->set(account, new Double(1000000));
    }
    config->set("cash", cashByAccount);
    config->set("dataType", dataType);

    VectorSP universe = Util::createVector(DT_STRING, 1, 1);
    universe->set(0, new String("BTCUSDT"));
    config->set("Universe", universe);

    DictionarySP context = Util::createDictionary(DT_STRING, nullptr, DT_ANY, nullptr);
    VectorSP contextUniverse = Util::createVector(DT_STRING, 1, 1);
    contextUniverse->set(0, new String("BTCUSDT_futures"));
    context->set("Universe", contextUniverse);
    TableSP logTable = Util::createTable({"tradeDate", "time", "info"}, {DT_DATE, DT_TIMESTAMP, DT_STRING}, 0, 0);
    context->set("log", logTable);
    config->set("context", context);

    return config;
}

/**
 * Loads perpetual-futures funding rates from dfs://CryptocurrencyDay/fundingRate.
 *
 * Selects rows whose symbol is in `codes` and whose settlement date lies in [begt, endt],
 * ordered by settlementTime. Symbols get a "_futures" suffix and lastFundingRate is cast to
 * DECIMAL128(8). The arguments are uploaded to the connection as codes / begt / endt.
 */
ddb::TableSP CryptoBacktest::getCryptoFundingRate(ddb::VectorSP codes, ddb::ConstantSP begt, ddb::ConstantSP endt) {
    static const char* const kScript = R"(
//获取资金费率
def getCryptoFundingRate(codes,begt,endt){
	return select  symbol+"_futures" as symbol,settlementTime,
	decimal128(lastFundingRate,8) as lastFundingRate from 
	loadTable("dfs://CryptocurrencyDay","fundingRate")  where settlementTime.date()>=begt and 
	settlementTime.date()<=endt and symbol in codes order by settlementTime
}

getCryptoFundingRate(codes,begt,endt)
        )";

    conn_->upload("codes", codes);
    conn_->upload("begt", begt);
    conn_->upload("endt", endt);
    return conn_->run(kScript);
}

/**
 * Loads one-minute K-lines from dfs://CryptocurrencyKLine/minKLine for `codes`, keeping
 * rows whose event date lies in [begt, endt]. The result is ordered by tradeTime.
 *
 * `futuresOrSpot` selects rows by symbolSource:
 *   - 0: sources ending in "Futures"; symbol suffix "_futures", contractType 2.
 *   - 1: sources ending in "Spot"; symbol suffix "_spot", contractType 0.
 *   - any other value: both sets, merged.
 *
 * Prices and volumes are cast to DECIMAL128(8). Limit, settlement and previous-close
 * columns are filled with zero. The arguments are uploaded to the connection under their
 * own names.
 */
ddb::TableSP CryptoBacktest::getBinanceMinKLine(ddb::VectorSP codes, ddb::ConstantSP begt, ddb::ConstantSP endt,
                                                ddb::ConstantSP futuresOrSpot) {
    static const char* const kScript = R"(
def getBinanceMinKLine(codes,begt,endt,futuresOrSpot=0){
 	'''
 	futuresOrSpot=0，期货；1，现货；2、现货和期货都有
 	'''
	if(futuresOrSpot==0){
		data =select symbol+"_futures" as symbol,symbolSource as  symbolSource,eventTime as tradeTime,
		eventTime.date() as tradingDay,decimal128(open,8) as open,decimal128(low,8) as low,
		decimal128(high,8) as high,decimal128(close,8) as close,decimal128(volume,8) as volume,decimal128(quoteVolume,8) as amount,
		decimal128(0.,8) as upLimitPrice,decimal128(0.,8) as downLimitPrice,fixedLengthArrayVector(10.*[takerBuyBase,takerBuyQuote,volCcy]) as signal,
		decimal128(0.,8) as prevClosePrice,decimal128(0.,8) as settlementPrice,decimal128(0.,8) as prevSettlementPrice, 2 as contractType  
		from loadTable("dfs://CryptocurrencyKLine","minKLine")
		where symbol in codes and symbolSource like "%Futures" and
		eventTime.date()>=begt and eventTime.date()<=endt order by tradeTime
	}
	else if(futuresOrSpot==1){
		data =select symbol+"_spot" as symbol,symbolSource as  symbolSource,eventTime as tradeTime,
		eventTime.date() as tradingDay,decimal128(open,8) as open,decimal128(low,8) as low,
		decimal128(high,8) as high,decimal128(close,8) as close,decimal128(volume,8) as volume,decimal128(quoteVolume,8) as amount,
		decimal128(0.,8) as upLimitPrice,decimal128(0.,8) as downLimitPrice,fixedLengthArrayVector(10.*[takerBuyBase,takerBuyQuote,volCcy]) as signal,
		decimal128(0.,8) as prevClosePrice,decimal128(0.,8) as settlementPrice,decimal128(0.,8) as prevSettlementPrice,0 as contractType  
		from loadTable("dfs://CryptocurrencyKLine","minKLine")
		where symbol in codes and symbolSource like "%Spot" and
		eventTime.date()>=begt and eventTime.date()<=endt order by tradeTime
	}
	else{
		futures =select symbol+"_futures" as symbol,symbolSource as  symbolSource,eventTime as tradeTime,
		eventTime.date() as tradingDay,decimal128(open,8) as open,decimal128(low,8) as low,
		decimal128(high,8) as high,decimal128(close,8) as close,decimal128(volume,8) as volume,decimal128(quoteVolume,8) as amount,
		decimal128(0.,8) as upLimitPrice,decimal128(0.,8) as downLimitPrice,fixedLengthArrayVector(10.*[takerBuyBase,takerBuyQuote,volCcy]) as signal,
		decimal128(0.,8) as prevClosePrice,decimal128(0.,8) as settlementPrice,decimal128(0.,8) as prevSettlementPrice,2 as contractType  
		from loadTable("dfs://CryptocurrencyKLine","minKLine")
		where symbol in codes and symbolSource like "%Futures" and
		eventTime.date()>=begt and eventTime.date()<=endt order by tradeTime
		spot =select symbol+"_spot" as symbol,symbolSource as  symbolSource,eventTime as tradeTime,
		eventTime.date() as tradingDay,decimal128(open,8) as open,decimal128(low,8) as low,
		decimal128(high,8) as high,decimal128(close,8) as close,decimal128(volume,8) as volume,decimal128(quoteVolume,8) as amount,
		decimal128(0.,8) as upLimitPrice,decimal128(0.,8) as downLimitPrice,fixedLengthArrayVector(10.*[takerBuyBase,takerBuyQuote,volCcy]) as signal,
		decimal128(0.,8) as prevClosePrice,decimal128(0.,8) as settlementPrice,decimal128(0.,8) as prevSettlementPrice,0 as contractType  
		from loadTable("dfs://CryptocurrencyKLine","minKLine")
		where symbol in codes and symbolSource like "%Spot" and
		eventTime.date()>=begt and eventTime.date()<=endt order by tradeTime
		data = select * from futures.append!(spot) order by tradeTime 
	}
	return data
}

getBinanceMinKLine(codes,begt,endt,futuresOrSpot)
        )";

    conn_->upload("codes", codes);
    conn_->upload("begt", begt);
    conn_->upload("endt", endt);
    conn_->upload("futuresOrSpot", futuresOrSpot);
    return conn_->run(kScript);
}

/**
 * Derives the security reference table from a quotation table `data`.
 *
 * Produces one row per (symbol, symbolSource) carrying the last contractType, plus fixed
 * parameters:
 *   - contractSize: 100 for contractType 2, otherwise 1;
 *   - tradeUnit: 0.2 for contractType 2, otherwise 1;
 *   - marginRatio: 0.2;
 *   - priceUnit, priceTick, takerRate, makerRate: zero;
 *   - deliveryCommissionMode and fundingSettlementMode: 1 for contractType 2, otherwise 2.
 *
 * The script runs in a freshly created local session (not on conn_), with `data` bound
 * only for the duration of the call.
 */
ddb::TableSP CryptoBacktest::getCryptoSecurityReference(ddb::TableSP data) {
    static const char* const kScript = R"(
def getCryptoSecurityReference(data){
	securityReference=select last(contractType)  as contractType from data  group by symbol, symbolSource
	update securityReference set optType = 1
	update securityReference set strikePrice = decimal128(0, 8)
	update securityReference set contractSize = iif(contractType==2,decimal128(100.,8),decimal128(1.,8))
	update securityReference set marginRatio = decimal128(0.2,8)
	update securityReference set tradeUnit = iif(contractType==2,decimal128(0.2,8),decimal128(1.,8))
	update securityReference set priceUnit = decimal128(0.,8)
	update securityReference set priceTick = decimal128(0.,8)
	update securityReference set takerRate = decimal128(0.,8)
	update securityReference set makerRate = decimal128(0.,8)
	update securityReference set deliveryCommissionMode = iif(contractType==2,1,2)
	update securityReference set fundingSettlementMode = iif(contractType==2,1,2)
	update securityReference set lastTradeTime = timestamp()
	return securityReference	
}

getCryptoSecurityReference(data)
        )";

    auto localSession = DolphinDBLib::createSession();
    localSession->getHeap()->addItem("data", data);
    TableSP securityReference = DolphinDBLib::execute(localSession, kScript);
    localSession->getHeap()->removeItem("data");
    return securityReference;
}

/**
 * Builds the security reference table from the (symbol, symbolSource) pairs found in
 * dfs://CryptocurrencyTick/depth for `codes` with event dates in [begt, endt].
 *
 * contractType is 2 when symbolSource contains "Futures" at a position > 0, otherwise 0.
 * The symbol is suffixed "_futures" or "_spot" accordingly. The remaining columns come
 * from an embedded copy of the getCryptoSecurityReference script.
 *
 * NOTE(review): `strpos(...)>0` treats a symbolSource that *starts* with "Futures" as spot;
 * confirm that source names always carry a prefix.
 */
ddb::TableSP CryptoBacktest::getCryptoSecurityReferenceFromDepth(ddb::VectorSP codes, ddb::ConstantSP begt,
                                                                 ddb::ConstantSP endt) {
    static const char* const kScript = R"(
def getCryptoSecurityReference(data){
	securityReference=select last(contractType)  as contractType from data  group by symbol, symbolSource
	update securityReference set optType = 1
	update securityReference set strikePrice = decimal128(0, 8)
	update securityReference set contractSize = iif(contractType==2,decimal128(100.,8),decimal128(1.,8))
	update securityReference set marginRatio = decimal128(0.2,8)
	update securityReference set tradeUnit = iif(contractType==2,decimal128(0.2,8),decimal128(1.,8))
	update securityReference set priceUnit = decimal128(0.,8)
	update securityReference set priceTick = decimal128(0.,8)
	update securityReference set takerRate = decimal128(0.,8)
	update securityReference set makerRate = decimal128(0.,8)
	update securityReference set deliveryCommissionMode = iif(contractType==2,1,2)
	update securityReference set fundingSettlementMode = iif(contractType==2,1,2)
	update securityReference set lastTradeTime = timestamp()
	return securityReference	
}

def getCryptoSecurityReferenceFromDepth(codes,begt,endt){
	depth =select symbolSource
			from loadTable("dfs://CryptocurrencyTick","depth") where symbol	 in codes and 
			eventTime.date()>=begt and eventTime.date() <=endt group by symbol,symbolSource

	update depth set contractType=iif(strpos(symbolSource,"Futures")>0,2,0)
	update depth set symbol=symbol+"_"+iif(strpos(symbolSource,"Futures")>0,"futures","spot")
	return getCryptoSecurityReference(depth)
}

getCryptoSecurityReferenceFromDepth(codes,begt,endt)
        )";

    conn_->upload("codes", codes);
    conn_->upload("begt", begt);
    conn_->upload("endt", endt);
    return conn_->run(kScript);
}

/**
 * Builds snapshot quotations for `codes` for every day in begt..endt.
 *
 * For each day:
 *   - Depth rows (dfs://CryptocurrencyTick/depth) are window-joined with trades
 *     (dfs://CryptocurrencyTick/trade) at identical timestamps per symbol (window 0:0).
 *   - The joined columns are lastPrice, traded quantity as totalBidQty / totalOfferQty,
 *     and high/low price.
 *   - Symbols are suffixed "_futures" / "_spot" and contractType is set from symbolSource.
 *   - lastPrice is forward- then back-filled per symbol; the quantity columns are
 *     null-filled with 0.
 *
 * The script throws when a day after the first has no depth or trade rows, or when the
 * overall result is empty. The result is ordered by timestamp. The arguments are uploaded
 * to the connection as codes / begt / endt.
 *
 * Fix: totalOfferQty used to be null-filled from totalBidQty (copy-paste slip); it is now
 * filled from its own column.
 */
ddb::TableSP CryptoBacktest::getSnapshotData(ddb::VectorSP codes, ddb::ConstantSP begt, ddb::ConstantSP endt) {
    conn_->upload("codes", codes);
    conn_->upload("begt", begt);
    conn_->upload("endt", endt);

    TableSP ret = conn_->run(R"(
 def getSnapshotData(codes,begt,endt){
	//获取快照数据
	//
	//print("开始获取快照行情数据数据！")
	snapshotData =table(100000:0,[`symbol],[SYMBOL])
	for (idate in begt..endt){
		depth = select symbol ,symbolSource ,eventTime as timestamp,
			eventTime.date() as tradingDay,decimal128(bidPrice,8) as bidPrice,decimal128(bidQty,8) as bidQty,
			decimal128(askPrice,8) as offerPrice,decimal128(askQty,8) as offerQty
			from loadTable("dfs://CryptocurrencyTick","depth") where symbol in codes and 
			eventTime.date()==idate order by timestamp
		update depth set symbol = symbol+"_"+iif(strpos(symbolSource,"Futures")>0,"futures","spot")
		aggTrade = select symbol ,symbolSource, eventTime as timestamp,price as lastPrice,
				decimal128(quantity,8) as quantity from loadTable("dfs://CryptocurrencyTick","trade") where symbol in codes and 
				eventTime.date()==idate order by timestamp
		update aggTrade set symbol=symbol+"_"+iif(strpos(symbolSource,"Futures")>0,"futures","spot")
		if(depth.size()<1 and snapshotData.size()>0){
			ss = "表（dfs://CryptocurrencyTick，depth）中 没有"
			throw(ss+string(concat(codes,","))+ ": "+ string(idate)+"的深度行情数据！")
			return		
		}
		if(aggTrade.size()<1 and snapshotData.size()>0){
			ss = "表（dfs://CryptocurrencyTick，trade）中 没有"
			throw(ss+string(concat(codes,","))+ ": "+ string(idate)+"的交易数据！")
			return		
		}
		//合并深度行情+成交行情数据
		messageTable = wj(depth, aggTrade, 0:0, 
		<[last(lastPrice) as lastPrice, sum(quantity) as totalBidQty, sum(quantity) as totalOfferQty, max(lastPrice) as highPrice, 
		min(lastPrice) as lowPrice]>, 
		`symbol`timestamp)
		depth = NULL
		aggTrade = NULL
		// fill the missing columns
		signal = take(array(DOUBLE[], 0).append!(0), count(messageTable))
		messageTable = select  symbol, symbolSource,  timestamp, date(timestamp) as tradingDay, 
		decimal128(lastPrice,8) as lastPrice, decimal128(0., 8) as upLimitPrice, decimal128(0., 8) as downLimitPrice, 
		decimal128(totalBidQty,8) as totalBidQty,  
		decimal128(totalOfferQty,8) as totalOfferQty,
		bidPrice, bidQty,  offerPrice,  offerQty, decimal128(highPrice,8) as highPrice, decimal128(lowPrice,8) as lowPrice, signal,
		decimal128(0, 8) as prevClosePrice, decimal128(0, 8) as settlementPrice, decimal128(0, 8) as prevSettlementPrice, 
		2 as contractType from messageTable
		update messageTable set lastPrice = ffill(lastPrice) context by symbol 
		update messageTable set lastPrice = bfill(lastPrice)  context by symbol 
		update messageTable set totalBidQty = nullFill(totalBidQty, 0)  context by symbol 
		update messageTable set totalOfferQty = nullFill(totalOfferQty, 0)  context by symbol 
		update messageTable set contractType=iif(strpos(symbolSource,"Futures")>0,2,0)
		if(snapshotData.size() < 1){
			snapshotData = messageTable
		}
		else{
			snapshotData=snapshotData.append!(messageTable)
		}
		messageTable = NULL			
	}
	if(snapshotData.size() < 1){
		ss = "表dfs://CryptocurrencyTick，depth 或者 trade 中 没有 "
		throw(ss+string(concat(codes,","))+ ": "+ string(begt)+":"+string(endt)+" 的深度行情或者成交数据，请先检查数据！")
	}
	return select * from snapshotData order by timestamp
}

getSnapshotData(codes,begt,endt)
        )");

    return ret;
}