# Default locations of the model checkpoints and input video (original hard-coded values).
_DENSENET_MODEL_PATH = 'C:\\\\2024\\\\On-Campus Activities\\\\Capstone\\\\2024_Capstone\\\\Capstone Project\\\\code\\\\densenet_model70.pth'
_YOLO_MODEL_PATH = 'C:\\\\2024\\\\On-Campus Activities\\\\Capstone\\\\2024_Capstone\\\\Capstone Project\\\\code\\\\5_10_yolo_model_(epoches50_batch8_data_renew_5class).pt'
_VIDEO_PATH = 'C:\\\\2024\\\\On-Campus Activities\\\\Capstone\\\\2024_Capstone\\\\Capstone Project\\\\code\\\\cc_3_220305_vehicle_252_47329 (1).mp4'


def _frame_step(fps, desired_fps):
    """Return how many source frames to advance per processed frame (always >= 1).

    ``CAP_PROP_FPS`` can report 0 (unknown), and the source can be slower than
    ``desired_fps``. Either case used to produce a step of 0, which crashed the
    ``frame_count % frame_skip`` test with ZeroDivisionError.
    """
    if fps <= 0 or desired_fps <= 0:
        return 1
    return max(1, round(fps / desired_fps))


def _iter_detections(results):
    """Yield ``(x1, y1, x2, y2, conf, cls_id)`` tuples from YOLO output.

    Two layouts are accepted:
      * objects exposing ``.boxes`` with ``xyxy`` / ``conf`` / ``cls`` fields.
        This is the result type of ``ultralytics.YOLO``, which the ``YOLO`` used
        here is presumed to be (TODO: confirm against the file's imports).
      * raw 6-element rows ``[x1, y1, x2, y2, conf, cls]``, the layout the
        original loop expected.
    Anything else is skipped.
    """
    if results is None:
        return
    for res in results:
        boxes = getattr(res, "boxes", None)
        if boxes is not None:
            # tolist() works for both torch tensors and numpy arrays, on any device.
            for (x1, y1, x2, y2), conf, cls_id in zip(
                boxes.xyxy.tolist(), boxes.conf.tolist(), boxes.cls.tolist()
            ):
                yield int(x1), int(y1), int(x2), int(y2), float(conf), int(cls_id)
        elif isinstance(res, np.ndarray) and res.shape[0] == 6:
            x1, y1, x2, y2 = (int(v) for v in res[:4])
            yield x1, y1, x2, y2, float(res[4]), int(res[5])


def _draw_detections(frame, results, names, max_class_id=4):
    """Draw YOLO detections with ``cls_id < max_class_id`` onto ``frame`` in place.

    Prints each drawn detection's class name. If none qualify, prints a single
    "no objects" message instead. Returns the number of boxes drawn.
    """
    drawn = 0
    for x1, y1, x2, y2, _conf, cls_id in _iter_detections(results):
        # Assuming collision class numbers are from 0 to max_class_id - 1.
        if cls_id >= max_class_id:
            continue
        label = names[cls_id]
        print("Object detected by YOLO within collision area:", label)
        cv2.rectangle(frame, (x1, y1), (x2, y2), (0, 255, 0), 2)
        cv2.putText(frame, label, (x1, y1 - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 2)
        drawn += 1
    if drawn == 0:
        print("No objects detected by YOLO within collision area.")
    return drawn


def main(densenet_model_path=_DENSENET_MODEL_PATH,
         yolo_model_path=_YOLO_MODEL_PATH,
         video_path=_VIDEO_PATH,
         desired_fps=15):
    """Run DenseNet collision detection on a video, and YOLO on flagged frames.

    Roughly ``desired_fps`` frames per second of video are passed to
    ``predict_collision``. When it returns 1, the frame is also run through the
    YOLO model, and detections with class id < 4 are drawn on it. Processed
    frames are shown in a window; press ``q`` to stop early.

    Args:
        densenet_model_path: Checkpoint path passed to ``load_densenet``.
        yolo_model_path: Weights path passed to ``YOLO``.
        video_path: Video file opened with ``cv2.VideoCapture``.
        desired_fps: Target number of processed frames per second of video.
    """
    # Load Densenet model
    densenet_model = load_densenet(densenet_model_path)

    # Load YOLO model
    yolo_model = YOLO(yolo_model_path)

    # Load video
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        print("Error opening video file.")
        return

    try:
        # Set desired resolution.
        # NOTE(review): OpenCV generally honours these only for camera sources;
        # for a video file they are likely ignored. Confirm before relying on them.
        desired_width = 1920
        desired_height = 1080
        cap.set(cv2.CAP_PROP_FRAME_WIDTH, desired_width)
        cap.set(cv2.CAP_PROP_FRAME_HEIGHT, desired_height)

        fps = int(cap.get(cv2.CAP_PROP_FPS))
        frame_skip = _frame_step(fps, desired_fps)

        # The inference device is fixed for the whole run, so choose it once.
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        frame_count = 0
        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break

            frame_index = frame_count
            frame_count += 1
            if frame_index % frame_skip != 0:
                continue

            # Predict collision using Densenet model
            collision = predict_collision(frame, densenet_model, device)
            print(collision)
            # If a collision is detected, pass the frame through the YOLO model.
            if collision == 1:
                print("Collision detected! Running YOLO model...")
                results = yolo_model(frame)
                print(results)
                _draw_detections(frame, results, yolo_model.names)

            cv2.imshow('Collision Detection', frame)
            # Mask to the low byte: some platforms set extra bits in waitKey's result.
            if (cv2.waitKey(1) & 0xFF) == ord('q'):
                break
    finally:
        # Release the capture and close windows even if inference raises.
        cap.release()
        cv2.destroyAllWindows()

# Script entry point: start the pipeline only when this file is executed
# directly, so that importing the module does not open the video.
if __name__ == "__main__":
    main()