# 檔案 1: 後端伺服器 (請存成 app.py)
# -------------------------------------
# 這個檔案會啟動一個本地伺服器，等待前端網頁的請求。

import flask
from flask import request, jsonify
import cv2
import numpy as np
import torch
from ultralytics import YOLO
import io
import base64
from PIL import Image
import tempfile
import os

# --- 您的原始分析函式 (幾乎無改動) ---
def crop_with_mask(img, mask, bbox_xyxy, pad):
    """Cut a padded box out of ``img`` and flatten everything outside ``mask``.

    The box ``bbox_xyxy`` (x1, y1, x2, y2; converted with ``int``) is grown by
    ``pad`` pixels on every side and clamped to the image bounds. Inside the
    resulting crop, pixels where ``mask`` is zero/False are overwritten with the
    average colour of the whole crop (taken before any overwrite, and cast to
    the image dtype on assignment). ``img`` itself is not modified.

    Returns:
        (crop, (x1, y1, x2, y2)): the edited copy and the clamped box actually
        used.
    """
    height, width = img.shape[:2]
    left, top, right, bottom = (int(v) for v in bbox_xyxy)
    left, top = max(0, left - pad), max(0, top - pad)
    right, bottom = min(width, right + pad), min(height, bottom + pad)

    patch = img[top:bottom, left:right].copy()
    keep = mask[top:bottom, left:right].astype(bool)
    # The right-hand side is evaluated first, so the mean covers the untouched crop.
    patch[np.logical_not(keep)] = patch.mean(axis=(0, 1), keepdims=True)
    return patch, (left, top, right, bottom)

def prep_rgb01(img_bgr, size, device):
    """Prepare a BGR image for the classifier network.

    Resizes ``img_bgr`` to ``size`` x ``size``, converts BGR -> RGB and scales
    the pixels to float32 in [0, 1].

    Returns:
        (rgb, tensor): the HxWx3 float32 RGB array, and the same pixels as a
        channels-first tensor with a leading batch dimension of one, moved to
        ``device``.
    """
    square = cv2.resize(img_bgr, (size, size))
    rgb = cv2.cvtColor(square, cv2.COLOR_BGR2RGB)
    rgb = rgb.astype(np.float32) / 255.0
    chw = torch.from_numpy(rgb).permute(2, 0, 1)
    return rgb, chw.unsqueeze(0).to(device)

def get_input_size(cla, default=224):
    """Work out the (square) input size the classifier expects.

    Looks first at ``cla.overrides["imgsz"]`` (only when ``overrides`` is a
    dict), then at ``cla.model.args`` (a dict, or an object with an ``imgsz``
    attribute). Falls back to ``default`` when neither yields a value. A
    list/tuple size is reduced to its first element; the result is an ``int``.
    """
    def from_overrides():
        overrides = getattr(cla, "overrides", None)
        if isinstance(overrides, dict):
            return overrides.get("imgsz", None)
        return None

    def from_model_args():
        if not hasattr(cla, "model"):
            return None
        model_args = getattr(cla.model, "args", None)
        if isinstance(model_args, dict):
            return model_args.get("imgsz", None)
        return getattr(model_args, "imgsz", None)

    size = from_overrides()
    if size is None:
        size = from_model_args()
    if size is None:
        size = default
    if isinstance(size, (list, tuple)):
        size = size[0]
    return int(size)

def forward_logits(cla_net, t):
    """Run ``cla_net`` on ``t`` and extract the score tensor from its output.

    Unwraps one level of list/tuple (keeping element 0). If what remains is a
    dict, returns the value stored under the first key present among
    "logits", "pred", "out", "y". Any other result (including a dict with none
    of those keys) is returned unchanged.
    """
    result = cla_net(t)
    if isinstance(result, (list, tuple)):
        result = result[0]
    if not isinstance(result, dict):
        return result
    found = next((key for key in ("logits", "pred", "out", "y") if key in result), None)
    return result if found is None else result[found]

def saliency_overlay(cla_net, roi_bgr, pred_id, in_size):
    """Return a vanilla-gradient saliency overlay for class ``pred_id``.

    The ROI is resized to ``in_size`` x ``in_size`` (via ``prep_rgb01``). The
    saliency of each pixel is the maximum over colour channels of
    |d score / d input|, where ``score`` is element ``[0, pred_id]`` of the
    network output. The map is shifted to start at 0, divided by its maximum
    (when positive), colourised with JET and blended 50/50 with the resized ROI.

    Args:
        cla_net: torch network; its first parameter's device is used for input.
        roi_bgr: BGR uint8 crop.
        pred_id: class index whose score is explained.
        in_size: side length the ROI is resized to.

    Returns:
        ``in_size`` x ``in_size`` x 3 uint8 BGR overlay.
    """
    device = next(cla_net.parameters()).device
    rgb01, t = prep_rgb01(roi_bgr, in_size, device)
    # Gradients are required even if the caller has disabled them (e.g. inside
    # torch.no_grad()); enable them only for this computation.
    with torch.enable_grad():
        t.requires_grad_(True)
        # Release any stale parameter gradients held by the network.
        cla_net.zero_grad(set_to_none=True)
        logits = forward_logits(cla_net, t)
        score = logits[0, int(pred_id)]
        # Only d(score)/d(input) is needed; autograd.grad does not accumulate
        # into the parameters' .grad buffers the way backward() does.
        (grad,) = torch.autograd.grad(score, t)
    sal = grad.detach().abs()[0].amax(dim=0)
    sal = sal - sal.min()
    peak = sal.max()
    if peak > 0:
        sal = sal / peak
    sal = sal.float().cpu().numpy()
    heat = cv2.applyColorMap((sal * 255).astype(np.uint8), cv2.COLORMAP_JET)
    base = cv2.cvtColor((rgb01 * 255).astype(np.uint8), cv2.COLOR_RGB2BGR)
    return (0.5 * base + 0.5 * heat).astype(np.uint8)

def gradcam_or_saliency(cla_net, roi_bgr, pred_id, in_size):
    """Return a Grad-CAM overlay for ``pred_id``, falling back to saliency.

    The target layer is chosen automatically. It is the last module output seen
    during the forward pass that is a 4-D tensor attached to the autograd graph
    and has more than one spatial cell.

    Grad-CAM then proceeds as follows:
      * weight each channel of that activation by its spatially averaged
        gradient of the class score;
      * sum over channels and apply ReLU;
      * normalise to [0, 1];
      * resize to ``in_size`` x ``in_size``;
      * blend 50/50 with the resized ROI.

    Falls back to ``saliency_overlay`` when no such activation exists or the
    class score does not depend on it.

    Args:
        cla_net: torch network; its first parameter's device is used for input.
        roi_bgr: BGR uint8 crop.
        pred_id: class index whose score is explained.
        in_size: side length the ROI is resized to before the forward pass.

    Returns:
        ``in_size`` x ``in_size`` x 3 uint8 BGR overlay.
    """
    device = next(cla_net.parameters()).device
    rgb01, t = prep_rgb01(roi_bgr, in_size, device)
    captured = {"feat": None}

    def record_feature_map(_module, _inputs, output):
        # Forward hooks fire as each module finishes, so whatever is stored here
        # after the pass is the last qualifying activation. A 1x1 map is skipped
        # because it yields a constant CAM, which normalises to all zeros.
        if (isinstance(output, torch.Tensor) and output.dim() == 4
                and output.requires_grad
                and output.shape[-2] * output.shape[-1] > 1):
            captured["feat"] = output

    # Scoped, unlike torch.set_grad_enabled(True): the caller's grad mode is
    # restored when this block exits.
    with torch.enable_grad():
        t.requires_grad_(True)
        handles = [m.register_forward_hook(record_feature_map) for m in cla_net.modules()]
        try:
            cla_net.zero_grad(set_to_none=True)
            logits = forward_logits(cla_net, t)
        finally:
            # Never leave hooks attached to the shared network, even on error.
            for handle in handles:
                handle.remove()
        score = logits[0, int(pred_id)]
        feat = captured["feat"]
        grad = None
        if feat is not None:
            # Differentiate directly w.r.t. the captured activation. A module
            # backward hook registered from inside a forward hook is not
            # attached to the pass already in progress, so it would never fire.
            (grad,) = torch.autograd.grad(score, feat, allow_unused=True)
        if grad is None:
            return saliency_overlay(cla_net, roi_bgr, pred_id, in_size)

    acts = feat.detach()[0].float()
    grads = grad.detach()[0].float()
    weights = grads.mean(dim=(1, 2))
    cam = torch.relu((weights[:, None, None] * acts).sum(dim=0))
    cam = cam - cam.min()
    peak = cam.max()
    if peak > 0:
        cam = cam / peak
    cam = cam.cpu().numpy()
    cam = cv2.resize(cam, (in_size, in_size), interpolation=cv2.INTER_LINEAR)
    heat = cv2.applyColorMap((cam * 255).astype(np.uint8), cv2.COLORMAP_JET)
    base = cv2.cvtColor((rgb01 * 255).astype(np.uint8), cv2.COLOR_RGB2BGR)
    return (0.5 * base + 0.5 * heat).astype(np.uint8)

def occlusion_map(cla, roi_bgr, pred_id, patch, stride, *, batch_size=32):
    """Build an occlusion-sensitivity overlay for one ROI.

    Slides a ``patch`` x ``patch`` window (clipped at the borders) over the ROI
    with step ``stride`` and fills it with the ROI's mean colour. For each
    position it records how much the classifier's probability for ``pred_id``
    drops. Each pixel keeps the largest drop among the windows covering it. The
    map is divided by its maximum (when positive), colourised with JET and
    blended 50/50 with the ROI.

    Args:
        cla: classifier whose ``predict(list_of_images, verbose=False)`` returns
            one result per image, each with ``.probs.data``.
        roi_bgr: BGR uint8 crop.
        pred_id: class index whose probability is tracked.
        patch: occluder side length in pixels (>= 1).
        stride: step between occluder positions in pixels (>= 1).
        batch_size: number of occluded images sent per ``cla.predict`` call
            (>= 1); batching avoids one predict call per window.

    Returns:
        uint8 overlay with the same shape as ``roi_bgr``.

    Raises:
        ValueError: if ``patch``, ``stride`` or ``batch_size`` is < 1.
    """
    # Validate before any model call: stride == 0 would otherwise crash in
    # range(), and non-positive values would silently yield an empty map.
    if patch < 1 or stride < 1 or batch_size < 1:
        raise ValueError(
            f"patch, stride and batch_size must be >= 1 "
            f"(got patch={patch}, stride={stride}, batch_size={batch_size})"
        )
    cls_idx = int(pred_id)
    H, W = roi_bgr.shape[:2]
    base = float(cla.predict([roi_bgr], verbose=False)[0].probs.data[cls_idx])
    heat = np.zeros((H, W), dtype=np.float32)
    mean_color = roi_bgr.mean(axis=(0, 1), keepdims=True)

    windows = [
        (y, min(y + patch, H), x, min(x + patch, W))
        for y in range(0, H, stride)
        for x in range(0, W, stride)
    ]
    for start in range(0, len(windows), batch_size):
        chunk = windows[start:start + batch_size]
        occluded = []
        for y, y2, x, x2 in chunk:
            occl = roi_bgr.copy()
            occl[y:y2, x:x2] = mean_color
            occluded.append(occl)
        results = cla.predict(occluded, verbose=False)
        for (y, y2, x, x2), res in zip(chunk, results):
            drop = max(0.0, base - float(res.probs.data[cls_idx]))
            heat[y:y2, x:x2] = np.maximum(heat[y:y2, x:x2], drop)

    if heat.max() > 0:
        heat /= heat.max()
    heat_color = cv2.applyColorMap((heat * 255).astype(np.uint8), cv2.COLORMAP_JET)
    return (0.5 * roi_bgr + 0.5 * heat_color).astype(np.uint8)

def paste_overlay_to_full(full_img, overlay_crop, bbox_xyxy, pad):
    """Alpha-blend ``overlay_crop`` into a copy of ``full_img``.

    The target region is ``bbox_xyxy`` (converted with ``int``) grown by
    ``pad`` pixels on every side and clamped to the image, using the same
    arithmetic as ``crop_with_mask``. The overlay is resized to that region
    and blended over the original pixels with weight 0.55.

    Returns:
        The blended copy, or ``full_img`` itself (not a copy) when the clamped
        region is empty.
    """
    img_h, img_w = full_img.shape[:2]
    left, top, right, bottom = (int(v) for v in bbox_xyxy)
    left, top = max(0, left - pad), max(0, top - pad)
    right, bottom = min(img_w, right + pad), min(img_h, bottom + pad)
    region_w, region_h = right - left, bottom - top
    if region_w <= 0 or region_h <= 0:
        return full_img

    resized = cv2.resize(overlay_crop, (region_w, region_h), interpolation=cv2.INTER_LINEAR)
    weight = 0.55
    blended = full_img.copy()
    window = blended[top:bottom, left:right]
    blended[top:bottom, left:right] = (weight * resized + (1 - weight) * window).astype(np.uint8)
    return blended

# --- Flask 伺服器設定 ---
# The Flask application object; the routes defined in this file register on it.
app = flask.Flask(import_name=__name__)

def numpy_to_base64(img_np):
    """Encode an OpenCV (NumPy) image as PNG and return it as a Base64 string.

    Raises:
        ValueError: if ``cv2.imencode`` reports that PNG encoding failed.
    """
    ok, buffer = cv2.imencode('.png', img_np)
    # imencode returns a success flag; do not Base64-encode a failed result.
    if not ok:
        raise ValueError("cv2.imencode failed to encode image as PNG")
    return base64.b64encode(buffer).decode('utf-8')

def _read_int_param(name, default, minimum):
    """Read integer form field ``name`` from the current request (``default`` if absent).

    Raises:
        ValueError: if the value is not an integer or is below ``minimum``.
    """
    raw = request.form.get(name, default)
    try:
        value = int(raw)
    except (TypeError, ValueError):
        raise ValueError(f"Parameter '{name}' must be an integer, got {raw!r}") from None
    if value < minimum:
        raise ValueError(f"Parameter '{name}' must be >= {minimum}, got {value}")
    return value


def _write_temp_model(data):
    """Write uploaded model bytes to a new ``.pt`` temp file and return its path (caller deletes it)."""
    with tempfile.NamedTemporaryFile(delete=False, suffix=".pt") as tmp:
        tmp.write(data)
        return tmp.name


@app.route("/analyze", methods=["POST"])
def analyze():
    """Segment, classify and explain every instance in an uploaded image.

    Request: multipart form data.
      * Files: ``image``, ``seg_model`` and ``cla_model``.
      * Optional integer fields: ``pad`` (default 12, >= 0), ``occl_patch``
        (default 24, >= 1) and ``occl_stride`` (default 12, >= 1).

    Response: JSON with ``full_cam_img`` / ``full_occ_img`` (Base64 PNGs of the
    whole image with per-instance overlays pasted in) and ``instance_results``
    (one record per detected instance).

    Errors are ``{"error": ...}`` with one of these statuses:
      * 400: missing files, bad parameters, or an undecodable image.
      * 404: nothing detected.
      * 500: any other failure.
    """
    # Check that all required uploads are present.
    if any(key not in request.files for key in ("image", "seg_model", "cla_model")):
        return jsonify({"error": "Missing files"}), 400

    # Read the uploads into memory.
    image_file = request.files['image'].read()
    seg_model_file = request.files['seg_model'].read()
    cla_model_file = request.files['cla_model'].read()

    # Validate parameters up front so bad input yields a JSON 400 rather than
    # an unhandled exception outside the try below.
    try:
        pad = _read_int_param('pad', 12, minimum=0)
        occl_patch = _read_int_param('occl_patch', 24, minimum=1)
        occl_stride = _read_int_param('occl_stride', 12, minimum=1)
    except ValueError as e:
        return jsonify({"error": str(e)}), 400

    # Decode the image before loading any model. imdecode returns None for data
    # it cannot decode; empty uploads are treated the same way.
    try:
        img = cv2.imdecode(np.frombuffer(image_file, np.uint8), cv2.IMREAD_COLOR) if image_file else None
    except cv2.error:
        img = None
    if img is None:
        return jsonify({"error": "Invalid image"}), 400

    temp_paths = []
    try:
        # YOLO(...) needs file paths, so the uploaded models go to temp files.
        # Both are tracked in temp_paths so that either one is cleaned up even
        # if a later step fails.
        # NOTE(review): YOLO(...) on a .pt file presumably deserializes it with
        # torch.load (pickle). Loading model files uploaded by untrusted
        # clients can therefore execute arbitrary code; only expose this
        # endpoint to trusted users.
        seg_model_path = _write_temp_model(seg_model_file)
        temp_paths.append(seg_model_path)
        cla_model_path = _write_temp_model(cla_model_file)
        temp_paths.append(cla_model_path)

        # Load the models.
        seg = YOLO(seg_model_path)
        cla = YOLO(cla_model_path)
        cla_net = cla.model.eval()

        # --- Core analysis pipeline ---
        # retina_masks=True returns masks at the original image resolution
        # instead of the inference resolution, which may include letterbox
        # padding and would not line up with img. The resize below is then
        # only a safety net.
        sres = seg.predict(source=img, verbose=False, retina_masks=True)[0]
        if sres.masks is None or len(sres.boxes) == 0:
            return jsonify({"error": "No objects detected"}), 404

        H, W = img.shape[:2]
        masks_small = sres.masks.data.cpu().numpy()
        masks = np.stack(
            [cv2.resize(m.astype(np.uint8), (W, H), interpolation=cv2.INTER_NEAREST) for m in masks_small],
            axis=0,
        )

        crops, metas = [], []
        for i, box in enumerate(sres.boxes):
            xyxy = box.xyxy.cpu().numpy().reshape(-1).tolist()
            roi, (x1, y1, x2, y2) = crop_with_mask(img, masks[i], xyxy, pad)
            crops.append(roi)
            metas.append({
                "i": i,
                "seg_cls": int(box.cls.item()),
                "seg_conf": float(box.conf.item()),
                # Already padded and clamped by crop_with_mask.
                "bbox_xyxy": [int(x1), int(y1), int(x2), int(y2)]
            })

        cres_list = cla.predict(source=crops, verbose=False)
        in_size = get_input_size(cla, default=224)

        instance_results = []
        full_cam = img.copy()
        full_occ = img.copy()
        cla_names = cla.names

        for meta, cres, roi in zip(metas, cres_list, crops):
            probs = cres.probs
            cla_cls = int(probs.top1)
            cla_conf = float(probs.top1conf)
            cla_cls_name = cla_names.get(cla_cls, f"Class {cla_cls}")

            grad_bgr = gradcam_or_saliency(cla_net, roi, cla_cls, in_size)
            occ_bgr = occlusion_map(cla, roi, cla_cls, occl_patch, occl_stride)

            # meta["bbox_xyxy"] is the already-padded crop box, so paste with
            # pad=0. Padding a second time would stretch each overlay over a
            # larger region than the crop it was computed from.
            full_cam = paste_overlay_to_full(full_cam, grad_bgr, meta["bbox_xyxy"], 0)
            full_occ = paste_overlay_to_full(full_occ, occ_bgr, meta["bbox_xyxy"], 0)

            instance_results.append({
                **meta,
                "cla_cls": cla_cls,
                "cla_cls_name": cla_cls_name,
                "cla_conf": round(cla_conf, 4),
                "crop_img": numpy_to_base64(roi),
                "gradcam_img": numpy_to_base64(grad_bgr),
                "occlusion_img": numpy_to_base64(occ_bgr)
            })

        # Build the JSON response.
        response_data = {
            "full_cam_img": numpy_to_base64(full_cam),
            "full_occ_img": numpy_to_base64(full_occ),
            "instance_results": instance_results
        }
        return jsonify(response_data)

    except Exception as e:
        return jsonify({"error": str(e)}), 500
    finally:
        # Remove the temp model files. A failed removal is logged rather than
        # replacing the response with an error.
        for path in temp_paths:
            try:
                os.remove(path)
            except OSError:
                app.logger.warning("Could not remove temp file %s", path)

@app.route('/')
def index():
    """Serve the front-end page ``index.html`` via ``flask.send_from_directory``.

    In a real deployment, static front-end files are usually served by a
    dedicated server such as Nginx instead.
    """
    front_end_dir, page = '.', 'index.html'
    return flask.send_from_directory(front_end_dir, page)


if __name__ == "__main__":
    # Start the Flask development server.
    # host='0.0.0.0' makes it reachable from other devices on the local network.
    # Because the server is network-reachable, the interactive Werkzeug
    # debugger, which can execute arbitrary code, is off unless explicitly
    # requested with FLASK_DEBUG=1/true/yes.
    debug = os.environ.get("FLASK_DEBUG", "").strip().lower() in ("1", "true", "yes")
    app.run(host='0.0.0.0', port=5000, debug=debug)
