###############################################################################
# Block 1 : Imports
###############################################################################
import sys, time, os, json, joblib
from datetime import datetime

import MetaTrader5 as mt5
import pandas as pd, numpy as np
from sklearn.linear_model import LogisticRegression

# Warn (but keep running) when the interpreter is newer than 3.11.
if sys.version_info[:2] >= (3, 12):
    print(
        f"{datetime.now()}: WARNING – MetaTrader5 wheels exist only for "
        f"Python ≤ 3.11; you are on {sys.version.split()[0]}",
        flush=True,
    )

###############################################################################
# Block 2 : User settings
###############################################################################
# MT5 account credentials: replace the xxxxx placeholders with your own
# account number and password, and set the broker server name.
LOGIN, PASSWORD, SERVER = xxxxx, "xxxxx", "server-Demo"

SYMBOL         = "XAUUSD"                 # change when you switch instruments
TF_ENTRY       = mt5.TIMEFRAME_M5         # timeframe of the EMA-cross entry signal
TF_TREND       = mt5.TIMEFRAME_H1         # higher timeframe used as trend filter
EMA_PERIOD     = 22                       # EMA span used on both timeframes
LOT_SIZE       = 0.01                     # volume sent with every order
RR_RATIO       = 2                        # TP distance = RR_RATIO × SL distance

THRESHOLD       = 0.60    # trade only if P(win) ≥ 60 %
WARMUP_TRADES   = 30      # model filter is bypassed until this many trades are recorded

# Folder where the per-symbol feature log and model are saved.
BASE_DIR   = r"C:\Users\folder"
SYMBOL_DIR = os.path.join(BASE_DIR, SYMBOL)
os.makedirs(SYMBOL_DIR, exist_ok=True)

FEATURE_FILE = os.path.join(SYMBOL_DIR, "features.txt")     # JSON-lines feature/label log
MODEL_FILE   = os.path.join(SYMBOL_DIR, "lr_model.joblib")  # persisted LogisticRegression

RETRAIN_EVERY = 28       # retrain after this many closed trades
LOOP_SECONDS  = 300       # 5-minute cadence

###############################################################################
# Block 3 : Strategy & feature engineering
###############################################################################
def fetch_rates(timeframe, bars=500):
    """Return the most recent *bars* candles of SYMBOL on *timeframe*.

    An empty DataFrame is returned when MT5 hands back None (no data).
    """
    raw = mt5.copy_rates_from_pos(SYMBOL, timeframe, 0, bars)
    if raw is None:
        return pd.DataFrame()
    return pd.DataFrame(raw)

def indicator_pack(df, ema_period=None, period=14):
    """Add ``ema``, ``atr`` and ``adx`` columns to an OHLC DataFrame.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain ``high``, ``low`` and ``close`` columns.  It is modified
        in place and also returned; an empty frame is returned unchanged.
    ema_period : int, optional
        Span of the close-price EMA.  Defaults to the module-level EMA_PERIOD.
    period : int, default 14
        Window for the ATR mean, the directional-movement sums and the ADX
        smoothing.  Plain rolling sums/means are used (not Wilder smoothing).

    Returns
    -------
    pandas.DataFrame
        *df* with the three indicator columns added.
    """
    if df.empty:
        return df
    if ema_period is None:
        ema_period = EMA_PERIOD
    df["ema"] = df["close"].ewm(span=ema_period, adjust=False).mean()

    # True range = max(high-low, |high-prev_close|, |low-prev_close|).
    prev_close = df["close"].shift()
    tr = pd.concat(
        [
            df["high"] - df["low"],
            (df["high"] - prev_close).abs(),
            (df["low"] - prev_close).abs(),
        ],
        axis=1,
    ).max(axis=1)
    df["atr"] = tr.rolling(period).mean()

    # Directional movement.  Series.where keeps df's index, so the rolling
    # sums stay aligned with tr_sum even when df has a non-default index
    # (wrapping np.where output in a bare Series would reset it to 0..n-1).
    up, dn = df["high"].diff(), -df["low"].diff()
    plus_dm = up.where((up > dn) & (up > 0), 0.0)
    minus_dm = dn.where((dn > up) & (dn > 0), 0.0)
    tr_sum = tr.rolling(period).sum()
    plus_di = 100 * plus_dm.rolling(period).sum() / tr_sum
    minus_di = 100 * minus_dm.rolling(period).sum() / tr_sum
    dx = (plus_di - minus_di).abs() / (plus_di + minus_di) * 100
    df["adx"] = dx.rolling(period).mean()
    return df

# ANSI colour codes for the filter-progress printout.
_GREEN = "\033[92m"
_RED   = "\033[91m"
_END   = "\033[0m"


def _print_filter_progress(side, checks):
    """Print one ✔/✖ line per (label, ok) filter plus a passed/total summary."""
    print(f"{datetime.now()}: {side} filter progress:", flush=True)
    for label, ok in checks:
        status = f"{_GREEN}✔{_END}" if ok else f"{_RED}✖{_END}"
        print(f"   {status} {label}: {ok}", flush=True)
    passed = sum(1 for _, ok in checks if ok)
    print(f"   {passed}/{len(checks)} filters passed for {side}", flush=True)


def raw_signal():
    """Evaluate the EMA-cross strategy on the latest entry/trend candles.

    Returns
    -------
    tuple
        ``(signal, last, trend_last, atr_median)`` where *signal* is "BUY",
        "SELL" or None, *last* is the latest entry-timeframe row,
        *trend_last* the latest trend-timeframe row and *atr_median* the
        median ATR over the fetched entry bars.  On a feed problem all four
        items are None.
    """
    entry_df = indicator_pack(fetch_rates(TF_ENTRY, 300))
    trend_df = indicator_pack(fetch_rates(TF_TREND, 300))
    # The crossover compares the last two entry bars; fewer than two is a
    # feed problem, not an IndexError.
    if len(entry_df) < 2 or trend_df.empty:
        return None, None, None, None      # feed problem

    # NOTE(review): copy_rates_from_pos(..., 0, n) presumably includes the
    # bar that is still forming, so when the loop wakes on a 5-minute
    # boundary `last` may be only seconds old (its size/volume features near
    # zero).  Confirm whether the last *closed* bar (iloc[-2]) was intended.
    last, prev = entry_df.iloc[-1], entry_df.iloc[-2]
    trend_last = trend_df.iloc[-1]
    atr_median = entry_df["atr"].median()

    atr_ok = last.atr > atr_median
    adx_ok = last.adx > 20
    buy_checks = [
        ("crossed_up", prev.close < prev.ema and last.close > last.ema),
        ("trend_up", trend_last.close > trend_last.ema),
        ("atr_ok", atr_ok),
        ("adx_ok", adx_ok),
    ]
    sell_checks = [
        ("crossed_down", prev.close > prev.ema and last.close < last.ema),
        ("trend_down", trend_last.close < trend_last.ema),
        ("atr_ok", atr_ok),
        ("adx_ok", adx_ok),
    ]
    _print_filter_progress("BUY", buy_checks)
    _print_filter_progress("SELL", sell_checks)

    if all(ok for _, ok in buy_checks):
        return "BUY", last, trend_last, atr_median
    if all(ok for _, ok in sell_checks):
        return "SELL", last, trend_last, atr_median
    return None, last, trend_last, atr_median

def build_features(candle, trend_candle, atr_median):
    """Build the model feature dict for one entry candle.

    Parameters
    ----------
    candle : row with time, high, low, close, ema, atr, adx, tick_volume
        Latest entry-timeframe bar (as produced by raw_signal).
    trend_candle : row with close and ema
        Latest trend-timeframe bar.
    atr_median : float
        Median ATR used to flag high volatility.

    Returns
    -------
    dict
        Feature name -> value; ``timestamp`` is the raw bar time.
    """
    from datetime import timezone

    # MT5 bar times are epoch seconds; decode them as UTC so the `hour`
    # feature does not depend on the host machine's time zone / DST.
    # NOTE(review): rows logged before this change used host-local hours;
    # consider retraining on fresh data.
    hour = datetime.fromtimestamp(candle.time, tz=timezone.utc).hour
    return {
        "timestamp"       : candle.time,
        "hour"            : hour,
        "candle_size"     : candle.high - candle.low,
        "ema_distance"    : abs(candle.close - candle.ema),
        "atr"             : candle.atr,
        "adx"             : candle.adx,
        "volume"          : candle.tick_volume,
        "trend_above_ema" : int(trend_candle.close > trend_candle.ema),
        "range_status"    : int(candle.adx < 20),            # 1 = range-bound
        "volatility_level": int(candle.atr > atr_median),    # 1 = high vol
    }

###############################################################################
# Block 4 : Learning engine
###############################################################################
def save_feature(feat):
    """Append *feat* to FEATURE_FILE as a single JSON line."""
    with open(FEATURE_FILE, mode="a") as handle:
        handle.write(f"{json.dumps(feat)}\n")

def load_dataset(path=None):
    """Load the JSON-lines feature log into a DataFrame.

    Parameters
    ----------
    path : str, optional
        File to read; defaults to FEATURE_FILE.

    Returns
    -------
    pandas.DataFrame
        One row per valid line; empty when the file does not exist.  For an
        existing file an ``entered`` column is always present, with 0
        ("no trade") for legacy rows that lack the field.
    """
    if path is None:
        path = FEATURE_FILE
    if not os.path.isfile(path):
        return pd.DataFrame()

    records = []
    with open(path, encoding="utf-8") as f:
        for lineno, line in enumerate(f, 1):
            if not line.strip():
                continue                    # tolerate blank lines
            try:
                records.append(json.loads(line))
            except json.JSONDecodeError:
                # e.g. a line truncated by a crash mid-write; skipping it
                # keeps retraining possible instead of failing forever.
                print(f"{datetime.now()}: skipping malformed line {lineno} "
                      f"in {path}", flush=True)

    df = pd.DataFrame(records)
    if "entered" in df.columns:
        # Mixed files: rows written before the field existed come back NaN.
        df["entered"] = df["entered"].fillna(0).astype(int)
    else:                                   # legacy rows → assume “no trade”
        df["entered"] = 0
    return df

def train_model(df):
    """Fit a logistic-regression win/loss classifier on completed trades.

    Only rows with ``entered == 1`` *and* a recorded ``outcome`` are used:
    the main loop also logs the entry bar itself with ``entered == 1`` but
    no outcome, and NaN labels (or NaN features) make ``fit`` raise.

    Returns
    -------
    LogisticRegression or None
        The fitted model (also saved to MODEL_FILE), or None when there is
        nothing to learn yet: no labelled trades, or only one outcome class
        (LogisticRegression needs at least one win and one loss).
    """
    cols = ["hour","candle_size","ema_distance","atr","adx",
            "volume","trend_above_ema","range_status","volatility_level"]
    if df.empty or "entered" not in df.columns or "outcome" not in df.columns:
        return None
    trades = df[(df["entered"] == 1) & df["outcome"].notna()]
    trades = trades.dropna(subset=cols)
    if trades["outcome"].nunique() < 2:
        return None
    X, y = trades[cols], trades["outcome"].astype(int)
    model = LogisticRegression(max_iter=500).fit(X, y)
    joblib.dump(model, MODEL_FILE)
    return model

def get_model():
    """Load the persisted model from MODEL_FILE; None if it does not exist."""
    if not os.path.isfile(MODEL_FILE):
        return None
    return joblib.load(MODEL_FILE)

###############################################################################
# Block 5 : Main loop
###############################################################################
# Feature columns fed to the model (must match train_model).
_MODEL_COLS = ["hour","candle_size","ema_distance","atr","adx",
               "volume","trend_above_ema","range_status","volatility_level"]


def _closed_trade_profit(order_ticket):
    """Return the realised profit of the trade opened by *order_ticket*.

    Returns None while the position is still (partly) open or while its
    closing deal is not yet in the terminal's history.

    ``history_deals_get(ticket=...)`` filters on the *order* ticket, which
    only matches the opening deal (profit 0) - the closing deal belongs to a
    different order.  So the position id is read from the opening deal and
    the whole position's deals are inspected instead.
    """
    entry_deals = mt5.history_deals_get(ticket=order_ticket)
    if not entry_deals:
        return None
    position_id = entry_deals[0].position_id

    open_positions = mt5.positions_get(symbol=SYMBOL) or ()
    if any(pos.identifier == position_id for pos in open_positions):
        return None                                    # still open

    deals = mt5.history_deals_get(position=position_id) or ()
    if not any(d.entry != mt5.DEAL_ENTRY_IN for d in deals):
        return None                                    # exit not synced yet
    return sum(d.profit for d in deals)


def _send_market_order(signal, sl_pts):
    """Send a market *signal* ("BUY"/"SELL") order with SL *sl_pts* away.

    TP is placed RR_RATIO times further than the SL.  Returns the order
    ticket on success, otherwise None after printing the reason.
    """
    tick = mt5.symbol_info_tick(SYMBOL)
    if tick is None:
        print(f"{datetime.now()}: no tick for {SYMBOL} – {mt5.last_error()}",
              flush=True)
        return None
    is_buy = signal == "BUY"
    price = tick.ask if is_buy else tick.bid
    tp_pts = sl_pts * RR_RATIO

    # NOTE(review): no "type_filling" is set; some brokers reject requests
    # whose filling mode the symbol does not support - confirm for yours.
    req = {
        "action": mt5.TRADE_ACTION_DEAL,
        "symbol": SYMBOL,
        "volume": LOT_SIZE,
        "type"  : mt5.ORDER_TYPE_BUY if is_buy else mt5.ORDER_TYPE_SELL,
        "price" : price,
        "sl"    : price - sl_pts if is_buy else price + sl_pts,
        "tp"    : price + tp_pts if is_buy else price - tp_pts,
        "deviation": 20,
        "magic"    : 0,
    }
    res = mt5.order_send(req)
    if res is None:                # request rejected before reaching the server
        print(f"{datetime.now()}: send error {mt5.last_error()}", flush=True)
        return None
    if res.retcode != mt5.TRADE_RETCODE_DONE:
        print(f"{datetime.now()}: send error {res.retcode}", flush=True)
        return None
    return res.order


if not mt5.initialize(server=SERVER, login=LOGIN, password=PASSWORD):
    print(f"{datetime.now()}: MT5 initialise failed – {mt5.last_error()}", flush=True)
    sys.exit(1)

print(f"{datetime.now()}: Connected – {SYMBOL}", flush=True)

model = get_model()
df0   = load_dataset()

# Count closed (labelled) trades only: a trade's entry bar and its closing
# row both carry entered == 1, so counting `entered` would double-count.
total_trades_seen = int(df0["outcome"].notna().sum()) if "outcome" in df0.columns else 0
open_tickets      = {}    # order ticket -> feature dict of a running trade

try:
    while True:
        print(f"{datetime.now()}: Loop tick", flush=True)
        try:
            signal, candle, trend_candle, atr_median = raw_signal()
            if candle is None:                           # feed glitch
                time.sleep(LOOP_SECONDS)
                continue

            feat = build_features(candle, trend_candle, atr_median)
            feat["had_signal"] = int(signal is not None)
            feat["entered"]    = 0                       # default label

            # ----- probability from model ------------------------------------
            prob = 0.5
            if model is not None:
                prob = model.predict_proba(pd.DataFrame([feat])[_MODEL_COLS])[0, 1]

            use_filter = (model is not None) and (total_trades_seen >= WARMUP_TRADES)
            accept     = (prob >= THRESHOLD) if use_filter else True

            if signal and accept:
                ticket = _send_market_order(signal, feat["ema_distance"] * 2)
                if ticket is not None:
                    feat["entered"] = 1
                    open_tickets[ticket] = feat
                    print(f"{datetime.now()}: {signal} ticket={ticket} "
                          f"prob={prob:.2%}", flush=True)

            # Always save the bar, even if a trade was just sent.
            save_feature(feat)

            # ---- monitor open trades ----------------------------------------
            for ticket in list(open_tickets):
                profit = _closed_trade_profit(ticket)
                if profit is None:
                    continue
                trade_feat = open_tickets.pop(ticket)
                trade_feat["outcome"] = int(profit > 0)
                save_feature(trade_feat)
                total_trades_seen += 1
                print(f"{datetime.now()}: ticket {ticket} closed P/L={profit:.2f}",
                      flush=True)

                if total_trades_seen % RETRAIN_EVERY == 0:
                    df = load_dataset()
                    new_model = train_model(df)
                    # Keep the previous model if retraining produced nothing.
                    if new_model is not None:
                        model = new_model
                        labelled = int(df["outcome"].notna().sum())
                        print(f"{datetime.now()}: model retrained on "
                              f"{labelled} trades", flush=True)

        except Exception as e:
            print(f"{datetime.now()}: runtime error – {e}", flush=True)

        # align to the next exact 5-minute multiple
        time.sleep(LOOP_SECONDS - (time.time() % LOOP_SECONDS))
finally:
    # Release the terminal connection on Ctrl+C or any fatal error.
    mt5.shutdown()
