49 lines
1.9 KiB
Python
49 lines
1.9 KiB
Python
from typing import Any, Dict, List
|
|
from common.base_problem import BaseProblem
|
|
from problems.n_queens.generator import generate_n_queens_case
|
|
from problems.n_queens.solvers import solve_backtracking
|
|
from common.timer_utils import time_execution
|
|
|
|
class NQueensProblem(BaseProblem):
    """N-Queens puzzle: place ``n`` queens on an ``n x n`` board with no two attacking.

    A solution is a sequence indexed by row: ``solution[r]`` is the column of
    the queen placed in row ``r`` (this is the layout ``verify`` checks).
    """

    # Boards with n below this value are solved exhaustively (every solution is
    # counted); larger boards ask the solver for a single solution so that
    # ``solve`` stays fast enough for interactive (UI) use.
    FIND_ALL_THRESHOLD = 10

    def generate_case(self, n: int = 8) -> Dict[str, Any]:
        """Build an input case for an ``n x n`` board.

        Delegates entirely to ``generate_n_queens_case``. ``solve`` reads the
        board size from the ``"n"`` key of the returned dict — presumably the
        generator sets it; TODO confirm against the generator.
        """
        return generate_n_queens_case(n)

    def solve(self, input_data: Dict[str, Any], algorithm: str) -> Dict[str, Any]:
        """Solve the case with the named algorithm and time the run.

        Args:
            input_data: Case dict; the board size is read from ``"n"``
                (default 8).
            algorithm: Algorithm name. Only ``"backtracking"`` is supported.

        Returns:
            A dict with:
              * ``"algorithm"`` – the algorithm name passed in,
              * ``"solution_count"`` – number of solutions the solver returned
                (exhaustive only when ``n < FIND_ALL_THRESHOLD``; otherwise the
                solver was asked for a single solution),
              * ``"first_solution"`` – the first solution, or ``None``,
              * ``"time_seconds"`` – duration reported by ``time_execution``.

        Raises:
            ValueError: If ``algorithm`` is not recognised.
        """
        n = input_data.get("n", 8)

        if algorithm != "backtracking":
            raise ValueError(f"Unknown algorithm: {algorithm}")

        # Finding every solution is only done for small boards; for larger
        # ones a single solution keeps "solve" responsive for the UI.
        find_all = n < self.FIND_ALL_THRESHOLD
        result, duration = time_execution(solve_backtracking)(n, find_all)

        # Treat a None result (should the solver ever return one) as "no
        # solutions" instead of crashing on len(None).
        solutions = result or []

        return {
            "algorithm": algorithm,
            "solution_count": len(solutions),
            "first_solution": solutions[0] if solutions else None,
            "time_seconds": duration,
        }

    def verify(self, input_data: Any, result: Any) -> bool:
        """Check that the reported first solution is a valid N-Queens placement.

        There is no ground truth to compare against, so the placement itself is
        validated: it must hold exactly one queen per row for the requested
        board size, every column must lie on the board, and no two queens may
        share a column or a diagonal.

        Args:
            input_data: The case passed to ``solve``. When it is a dict, the
                board size is read from ``"n"`` (default 8, matching
                ``solve``); otherwise the size is inferred from the solution.
            result: The dict returned by ``solve``.

        Returns:
            ``True`` if the placement is valid. A missing/empty solution is
            accepted only for board sizes where that is expected (n = 2, n = 3,
            or n <= 0) or when the board size is unknown.
        """
        if isinstance(input_data, dict):
            expected_n = input_data.get("n", 8)
        else:
            expected_n = None
        sol = result.get("first_solution")

        if not sol:
            # Boards of size 1 and >= 4 always have a solution, so reporting
            # none there is a solver failure. With an unknown size, stay
            # permissive.
            return expected_n is None or expected_n <= 0 or expected_n in (2, 3)

        n = len(sol)
        if expected_n is not None and n != expected_n:
            return False  # wrong number of queens for the requested board

        used_cols = set()
        used_diags = set()       # r - c is constant along one diagonal direction
        used_anti_diags = set()  # r + c is constant along the other direction
        for r, c in enumerate(sol):
            if not 0 <= c < n:
                return False  # queen placed off the board
            if c in used_cols or (r - c) in used_diags or (r + c) in used_anti_diags:
                return False  # shares a column or diagonal with an earlier queen
            used_cols.add(c)
            used_diags.add(r - c)
            used_anti_diags.add(r + c)
        return True
|