Problem Solving/항해99

99클럽 코테 스터디 26일차 TIL + Backtracking, Binary Search

wrathlion 2024. 11. 23. 01:57

문제

Programmers: 주사위 고르기

  • 설명: $n$개의 주사위가 주어진다. 이때 두 사람이 반씩 나눠가진다고 할 때 A가 승리할 확률이 가장 높아지도록 주사위를 분배하여라.

풀이

접근 1

  • 가장 나이브하게 접근해보자
    • 주사위를 선택한다.
    • 선택된 모든 주사위의 합을 구한 뒤, 각 합들 쌍을 비교하며 승/패/무 비율을 구한다.
  • 시간복잡도
    • 주사위 선택: $\binom{10}{5}$
    • 모든 주사위의 합을 구하기: $6^5$
    • 각 합들의 쌍을 구한다: $(6^5)^2$

접근 2

  • 시간을 좀 더 줄여보자 각 합들의 쌍을 구하는 것을 이진탐색으로 해보자
    • 즉, 다른 사람의 주사위의 합들에서 내가 비교하려는 숫자와 동일한 숫자, 작은 숫자, 큰 숫자를
    • lower, upper_bound를 통해 쉽게 구할 수 있다.
  • 그렇다면 각 합들의 쌍을 구하는 연산을 $6^5 \log{6^5}$로 줄일 수 있다.

오늘의 회고

  • Python에서 이진 탐색 구현체를 처음 사용해보았다.
      from bisect import bisect_left, bisect_right
      bisect_left(list, target)
    • bisect_leftlower_bound와 동일한 결과를, bisect_rightupper_bound와 동일한 결과를 반환한다.

Code

# Programmers258790.py
from bisect import bisect_left, bisect_right
from copy import deepcopy

def get_dice_sum_list(dice, dice_list, team_list):
    n = len(dice_list)

    def solve(idx, curr):
        if idx >= n:
            team_list.append(curr)
        else:
            for i in range(6):
                solve(idx + 1, curr + dice[dice_list[idx]][i])
    solve(0, 0)

def solution(dice):
    n = len(dice)
    half = n // 2
    isTeam1 = [False] * n

    ans = (0, [])
    def select(team1, team2, idx):
        nonlocal ans
        if len(team1) == half and len(team2) == half:
            team1_score, team2_score = [], []
            get_dice_sum_list(dice, team1, team1_score)
            get_dice_sum_list(dice, team2, team2_score)

            team1_score.sort()

            win, lose, draw = 0, 0, 0
            for target in team2_score:
                low = bisect_left(team1_score, target)
                high = bisect_right(team1_score, target)

                lose += low
                draw += high - low
                win += len(team1_score) - high

            if ans[0] < win:
                ans = (win, deepcopy(team1))

        else:
            if len(team1) < half:
                team1.append(idx)
                select(team1, team2, idx + 1)
                team1.pop()

            if len(team2) < half:
                team2.append(idx)
                select(team1, team2, idx + 1)
                team2.pop()

    select([], [], 0)

    for i in range(len(ans[1])):
        ans[1][i] += 1

    return ans[1]