# 参数优化脚本 - 止盈止损组合测试
import MetaTrader5 as mt5
import pandas as pd
import numpy as np
from datetime import datetime
import pytz
from itertools import product

# ================== Basic configuration ==================
# Instrument, bar sizes and the backtest window.
SYMBOL = "EURUSD"
TIMEFRAME = mt5.TIMEFRAME_M15          # signal/execution bars
TREND_TIMEFRAME = mt5.TIMEFRAME_H4     # bars used for the EMA trend filter
START_DATE = datetime(2025, 1, 1)
END_DATE = datetime.now()
LOT_SIZE = 0.1

# Bar timestamps are localized as UTC and then converted to Beijing time.
UTC = pytz.timezone("UTC")
BEIJING = pytz.timezone("Asia/Shanghai")

# Price distances, expressed as multiples of one point.
POINT = 1e-5
BREAK_MARGIN = POINT * 5   # extra distance beyond the Asia high/low for a breakout
SLIPPAGE = POINT           # adverse fill offset applied to every entry

# Indicator lengths and trade filters.
ATR_PERIOD = 14
EMA_PERIOD = 20
MAX_TRADES_PER_DAY = 2
GAP_FILTER_MULT = 2.0      # skip the day when the 14:00 gap exceeds this many ATRs

# ================== MT5 initialisation ==================
# Connect to the terminal. On failure exit with a non-zero status via
# SystemExit: `quit()` is an interactive helper injected by the `site` module,
# may be missing (e.g. `python -S`, frozen builds) and reports success (0).
if not mt5.initialize():
    print("MT5初始化失败")
    raise SystemExit(1)

# Download the M15 bars for the configured window.
print(f"正在获取 {SYMBOL} M15 历史数据...")
rates = mt5.copy_rates_range(SYMBOL, TIMEFRAME, START_DATE, END_DATE)
if rates is None or len(rates) == 0:
    print("数据获取失败")
    mt5.shutdown()
    raise SystemExit(1)

# Index the bars by timestamp: epoch seconds -> UTC -> Beijing time, sorted.
df = pd.DataFrame(rates)
df['time'] = pd.to_datetime(df['time'], unit='s')
df.set_index('time', inplace=True)
df.index = df.index.tz_localize(UTC).tz_convert(BEIJING)
df.sort_index(inplace=True)

# Compute ATR. The true range is the largest of: high-low, |high-prev close|
# and |low-prev close|. np.maximum propagates NaN, so the first bar (which has
# no previous close) keeps a NaN true range.
_prev_close = df['close'].shift(1)
_span_high_low = df['high'] - df['low']
_span_high_prev = (df['high'] - _prev_close).abs()
_span_low_prev = (df['low'] - _prev_close).abs()
df['tr'] = np.maximum(_span_high_low, np.maximum(_span_high_prev, _span_low_prev))
# ATR is the simple moving average of the true range.
df['atr'] = df['tr'].rolling(window=ATR_PERIOD).mean()

# Fetch H4 bars and attach an H4 EMA20 column to the M15 frame for the trend
# filter. If the H4 download fails the filter is simply left disabled.
print("正在获取H4数据...")
h4_rates = mt5.copy_rates_range(SYMBOL, TREND_TIMEFRAME, START_DATE, END_DATE)
trend_filter_enabled = False
if h4_rates is not None and len(h4_rates) > 0:
    h4_df = pd.DataFrame(h4_rates)
    h4_df['time'] = pd.to_datetime(h4_df['time'], unit='s')
    h4_df.set_index('time', inplace=True)
    h4_df.index = h4_df.index.tz_localize(UTC).tz_convert(BEIJING)
    h4_df.sort_index(inplace=True)
    h4_df['ema20'] = h4_df['close'].ewm(span=EMA_PERIOD, adjust=False).mean()
    # NOTE(review): assumes MT5 stamps each bar with its OPEN time, so the EMA
    # on the row stamped T already contains the close printed at T+4h. Shift by
    # one row so an M15 bar only ever sees the last *completed* H4 bar; using
    # the unshifted value would leak future prices into the trend filter.
    h4_df['ema20_closed'] = h4_df['ema20'].shift(1)
    # For every M15 timestamp take the most recent H4 row at or before it.
    # A single forward-fill reindex replaces the former per-row boolean slice
    # (one full scan of the H4 frame per M15 bar).
    df['h4_ema20'] = h4_df['ema20_closed'].reindex(df.index, method='ffill')
    trend_filter_enabled = True

print(f"数据范围：{df.index[0]} 至 {df.index[-1]}，共 {len(df)} 根K线")

def get_asia_range(day_date):
    """Return (high, low, high - low) of the bars in [08:00, 14:00) on the given day.

    `day_date` is a timestamp on the target day; its hour/minute/second fields
    are overwritten to build the window bounds. Returns (None, None, None)
    when the module-level `df` has no bars inside the window.
    """
    window_open = day_date.replace(hour=8, minute=0, second=0, microsecond=0)
    window_close = day_date.replace(hour=14, minute=0, second=0, microsecond=0)
    # Start is inclusive, end is exclusive.
    session = df[(df.index >= window_open) & (df.index < window_close)]
    if session.empty:
        return None, None, None
    top = session['high'].max()
    bottom = session['low'].min()
    return top, bottom, top - bottom

def get_trend_direction(current_idx):
    """Compare the bar's close with the H4 EMA20 stored on the same row.

    Returns 1 when the close is above the EMA, -1 when below, and 0 when they
    are equal, when the EMA is missing, or when the trend filter is disabled.
    """
    if not trend_filter_enabled:
        return 0
    ema_level = df.loc[current_idx, 'h4_ema20']
    if pd.isna(ema_level):
        return 0
    bar_close = df.loc[current_idx, 'close']
    if bar_close > ema_level:
        return 1
    return -1 if bar_close < ema_level else 0

def _new_trade(side, entry_time, entry_price, atr_val, atr_stop_mult, take_profit_mult):
    """Build the state dict of a freshly opened position ('buy' or 'sell')."""
    direction = 1 if side == 'buy' else -1
    return {
        'type': side,
        'entry_time': entry_time,
        'entry_price': entry_price,
        'stop_loss': entry_price - direction * atr_val * atr_stop_mult,
        'take_profit_1': entry_price + direction * atr_val * take_profit_mult,
        'half_closed': False,
        'trail_activated': False,
        'profit': 0.0,
        'atr_entry': atr_val,
    }


def _leg_pnl(trade, exit_price, fraction):
    """Return the PnL of closing `fraction` of the position at `exit_price`.

    Uses the fixed multiplier 100000 * LOT_SIZE applied to the price move.
    """
    if trade['type'] == 'buy':
        move = exit_price - trade['entry_price']
    else:
        move = trade['entry_price'] - exit_price
    return move * 100000 * LOT_SIZE * fraction


def _summarize_trades(trades):
    """Aggregate a list of closed-trade dicts into the summary statistics dict."""
    if not trades:
        return {
            'total_trades': 0,
            'total_profit': 0,
            'win_rate': 0,
            'avg_win': 0,
            'avg_loss': 0,
            'profit_factor': 0,
            'max_dd': 0,
            'sharpe': 0
        }

    df_trades = pd.DataFrame(trades)
    total_profit = df_trades['profit'].sum()
    wins = df_trades[df_trades['profit'] > 0]
    losses = df_trades[df_trades['profit'] <= 0]

    win_rate = len(wins) / len(df_trades) * 100
    avg_win = wins['profit'].mean() if len(wins) > 0 else 0
    avg_loss = losses['profit'].mean() if len(losses) > 0 else 0
    # Ratio of average win to average loss. Guard on the value rather than the
    # count: break-even trades land in `losses` and can make avg_loss exactly 0.
    profit_factor = avg_win / abs(avg_loss) if len(wins) > 0 and avg_loss != 0 else 0

    # Monthly statistics, bucketed by the wall-clock month of the entry time.
    # The timezone is dropped explicitly so to_period() does not warn on
    # tz-aware data.
    entry_times = pd.to_datetime(df_trades['entry_time'])
    df_trades['entry_time'] = entry_times
    df_trades['month'] = entry_times.dt.tz_localize(None).dt.to_period('M')
    monthly = df_trades.groupby('month')['profit'].sum()
    monthly_cum = monthly.cumsum()
    # The running peak starts from the initial equity of 0; otherwise a losing
    # first month would not register as drawdown.
    running_peak = monthly_cum.cummax().clip(lower=0)
    max_dd = (running_peak - monthly_cum).max()

    sharpe = 0
    if len(monthly) > 1:
        monthly_returns = monthly.values
        dispersion = np.std(monthly_returns)
        if dispersion != 0:
            sharpe = np.mean(monthly_returns) / dispersion * np.sqrt(12)

    return {
        'total_trades': len(df_trades),
        'total_profit': total_profit,
        'win_rate': win_rate,
        'avg_win': avg_win,
        'avg_loss': avg_loss,
        'profit_factor': profit_factor,
        'max_dd': max_dd,
        'sharpe': sharpe
    }


def run_backtest(atr_stop_mult, atr_trail_mult, profit_protect_mult, trail_trigger_mult, take_profit_mult):
    """Run one backtest of the Asia-range breakout strategy over the global `df`.

    Flow per bar:
      * At the 14:00 bar the 08:00-14:00 range is computed; the day is skipped
        when the open gaps from the previous close by more than
        GAP_FILTER_MULT * ATR.
      * Between 14:00 and 16:00, a bar whose high/low exceeds the range by
        BREAK_MARGIN is a breakout candidate (subject to the trend filter).
        It is confirmed when the NEXT bar closes beyond the same level.
      * The position is opened at the confirmation bar's close (+/- SLIPPAGE),
        because the confirmation is only known once that bar has closed.
        Management starts on the following bar, so no price printed before the
        entry can touch the stop or the target.
      * Management: break-even stop, ATR trailing stop, half close at the first
        target, hard exit on the 20:30 bar, and a forced close on the last bar.

    Args:
        atr_stop_mult: initial stop distance, in ATRs.
        atr_trail_mult: trailing-stop distance from the bar extreme, in ATRs.
        profit_protect_mult: favourable excursion (ATRs) that moves the stop to
            break-even.
        trail_trigger_mult: favourable excursion (ATRs) that starts trailing.
        take_profit_mult: distance of the first (half-size) target, in ATRs.

    Returns:
        dict with keys total_trades, total_profit, win_rate, avg_win, avg_loss,
        profit_factor, max_dd and sharpe.
    """
    trades = []
    open_trade = None
    last_asia_range = None
    daily_trade_count = {}
    n_bars = len(df)

    for i in range(2, n_bars):
        current_time = df.index[i]
        current_candle = df.iloc[i]
        prev_candle = df.iloc[i-1]
        current_date = current_time.date()

        # Compute the Asia range at the 14:00 bar of each day.
        if current_time.hour == 14 and current_time.minute == 0:
            day_start = current_time.replace(hour=0, minute=0, second=0, microsecond=0)
            asia_high, asia_low, asia_range = get_asia_range(day_start)

            last_asia_range = None
            if asia_high is not None:
                gap = abs(current_candle['open'] - prev_candle['close'])
                atr_prev = prev_candle['atr']
                gap_too_large = not pd.isna(atr_prev) and gap > GAP_FILTER_MULT * atr_prev
                if not gap_too_large:
                    last_asia_range = (asia_high, asia_low, asia_range)
            daily_trade_count[current_date] = 0

        # Signal detection.
        if (open_trade is None and last_asia_range is not None
                and 14 <= current_time.hour < 16
                and daily_trade_count.get(current_date, 0) < MAX_TRADES_PER_DAY):

            asia_high, asia_low, asia_range = last_asia_range

            long_break = current_candle['high'] > asia_high + BREAK_MARGIN
            short_break = current_candle['low'] < asia_low - BREAK_MARGIN

            # Only trade in the direction of the H4 trend when it is known.
            trend_dir = get_trend_direction(current_time)
            if trend_filter_enabled and trend_dir != 0:
                if long_break and trend_dir != 1:
                    long_break = False
                if short_break and trend_dir != -1:
                    short_break = False

            # A confirmation bar must exist after the breakout bar.
            if (long_break or short_break) and i + 1 < n_bars:
                next_candle = df.iloc[i+1]

                long_confirm = long_break and next_candle['close'] > asia_high + BREAK_MARGIN
                short_confirm = short_break and next_candle['close'] < asia_low - BREAK_MARGIN

                if long_confirm or short_confirm:
                    entry_time = next_candle.name
                    atr_val = next_candle['atr']
                    if pd.isna(atr_val):
                        atr_val = asia_range

                    # Fill at the confirmation bar's CLOSE: the confirmation
                    # condition depends on that close, so filling at the same
                    # bar's open would use information from the future.
                    if long_confirm:
                        open_trade = _new_trade(
                            'buy', entry_time, next_candle['close'] + SLIPPAGE,
                            atr_val, atr_stop_mult, take_profit_mult)
                    else:
                        open_trade = _new_trade(
                            'sell', entry_time, next_candle['close'] - SLIPPAGE,
                            atr_val, atr_stop_mult, take_profit_mult)

                    daily_trade_count[current_date] = daily_trade_count.get(current_date, 0) + 1

        # Position management, only on bars strictly after the entry bar.
        if open_trade is not None and current_time > open_trade['entry_time']:
            t = open_trade
            is_long = t['type'] == 'buy'
            high = current_candle['high']
            low = current_candle['low']
            close = current_candle['close']
            atr_current = current_candle['atr']
            if pd.isna(atr_current):
                atr_current = t['atr_entry']

            # Best favourable excursion reached within this bar.
            current_profit = high - t['entry_price'] if is_long else t['entry_price'] - low

            # NOTE(review): the stop is tightened from this bar's extreme and
            # then tested against the same bar's opposite extreme. The true
            # intrabar order is unknown, so this ordering is optimistic.

            # Break-even activation.
            if not t['trail_activated'] and current_profit >= profit_protect_mult * atr_current:
                t['trail_activated'] = True
                if is_long:
                    t['stop_loss'] = max(t['stop_loss'], t['entry_price'] + 2*POINT)
                else:
                    t['stop_loss'] = min(t['stop_loss'], t['entry_price'] - 2*POINT)

            # Trailing stop: only ever moves in the trade's favour.
            if t['trail_activated'] and current_profit >= trail_trigger_mult * atr_current:
                if is_long:
                    t['stop_loss'] = max(t['stop_loss'], high - atr_current * atr_trail_mult)
                else:
                    t['stop_loss'] = min(t['stop_loss'], low + atr_current * atr_trail_mult)

            # Stop-loss takes precedence over the first target on the same bar.
            if is_long:
                stop_hit = low <= t['stop_loss']
                target_hit = high >= t['take_profit_1']
            else:
                stop_hit = high >= t['stop_loss']
                target_hit = low <= t['take_profit_1']

            exit_price = None
            if stop_hit:
                exit_price = t['stop_loss']
            elif target_hit and not t['half_closed']:
                # Bank half of the position at the first target.
                t['half_closed'] = True
                t['profit'] += _leg_pnl(t, t['take_profit_1'], 0.5)

            # Session close: flatten on the 20:30 bar.
            if exit_price is None and current_time.hour == 20 and current_time.minute == 30:
                exit_price = close

            if exit_price is not None:
                t['profit'] += _leg_pnl(t, exit_price, 0.5 if t['half_closed'] else 1.0)
                trades.append(t.copy())
                open_trade = None

    # Force-close whatever is still open at the last bar's close.
    if open_trade is not None:
        t = open_trade
        t['profit'] += _leg_pnl(t, df.iloc[-1]['close'], 0.5 if t['half_closed'] else 1.0)
        trades.append(t.copy())

    return _summarize_trades(trades)

# ================== Parameter-optimisation grid ==================
# All market data is already in memory and nothing below talks to the
# terminal, so release the MT5 connection before the long-running search.
# An interrupted run then cannot leave the connection open.
mt5.shutdown()

print("\n" + "="*60)
print("开始参数优化...")
print("="*60)

# Parameter grid. The keys match run_backtest's parameter names.
param_grid = {
    'atr_stop_mult': [1.2, 1.5, 1.8, 2.0],           # initial stop-loss multiple
    'atr_trail_mult': [0.8, 1.0, 1.2, 1.5],          # trailing distance multiple
    'profit_protect_mult': [1.5, 1.8, 2.0, 2.2],     # break-even trigger multiple
    'trail_trigger_mult': [2.5, 3.0, 3.5, 4.0],      # trailing trigger multiple
    'take_profit_mult': [2.5, 3.0, 3.5, 4.0, 4.5]    # first take-profit multiple
}

results = []
total_combinations = np.prod([len(v) for v in param_grid.values()])
print(f"待测试参数组合数量：{total_combinations}")

best_result = None
best_params = None
# Start from -inf rather than an arbitrary floor such as -999: with a floor, a
# grid whose scores are all below it would leave best_params as None and crash
# the report below.
best_score = float('-inf')

for count, combo in enumerate(product(*param_grid.values()), start=1):
    if count % 100 == 0:
        print(f"进度：{count}/{total_combinations} ({count/total_combinations*100:.1f}%)")

    params = dict(zip(param_grid, combo))
    result = run_backtest(**params)

    # Composite score: total profit + win/loss ratio + Sharpe - drawdown penalty.
    score = (result['total_profit'] * 0.4 +
             result['profit_factor'] * 30 +
             result['sharpe'] * 20 -
             result['max_dd'] * 0.3)

    # One flat row per combination: parameters, then metrics, then the score.
    results.append({**params, **result, 'score': score})

    if score > best_score:
        best_score = score
        best_result = result
        best_params = params

# Save the results, best score first.
df_results = pd.DataFrame(results)
df_results = df_results.sort_values('score', ascending=False)
df_results.to_csv("optimization_results.csv", index=False)

print("\n" + "="*60)
print("优化完成！")
print("="*60)

print("\n【最佳参数组合】")
print(f"ATR_STOP_MULT = {best_params['atr_stop_mult']}")
print(f"ATR_TRAIL_MULT = {best_params['atr_trail_mult']}")
print(f"PROFIT_PROTECT_MULT = {best_params['profit_protect_mult']}")
print(f"TRAIL_TRIGGER_MULT = {best_params['trail_trigger_mult']}")
print(f"TAKE_PROFIT_MULT = {best_params['take_profit_mult']}")

print("\n【最佳参数表现】")
print(f"总交易次数：{best_result['total_trades']}")
print(f"总盈亏：{best_result['total_profit']:.2f} USD")
print(f"胜率：{best_result['win_rate']:.1f}%")
print(f"平均盈利：{best_result['avg_win']:.2f} USD")
print(f"平均亏损：{best_result['avg_loss']:.2f} USD")
print(f"盈亏比：{best_result['profit_factor']:.2f}")
print(f"最大回撤：{best_result['max_dd']:.2f} USD")
print(f"夏普比率：{best_result['sharpe']:.2f}")

print("\n【TOP 10 参数组合】")
print(df_results[['atr_stop_mult', 'atr_trail_mult', 'profit_protect_mult',
                  'trail_trigger_mult', 'take_profit_mult', 'total_profit',
                  'profit_factor', 'sharpe', 'score']].head(10).to_string())

print("\n优化结果已保存至 optimization_results.csv")