Search on the blog

2015年8月4日火曜日

はじめてのMonte Carlo Tree Search

 以前から実装してみたかったMonte Carlo Tree Search(MCTS)を実装してみました。MCTSはその名のとおり、モンテカルロな木探索です。主にゲーム木の探索に用いられます。

 普通の全探索では木の節点すべてを調べますが、MCTSでは根から葉へのパスをいくつかランダムに選んで探索を行います。

 まず1度、根から葉へのパスを調べたとします。これはランダムにゲームを終局まで進めたことに対応します。これによりパス上の節点において、選んだ手を打ったときの勝率が概算できます。もちろん1回ランダムにゲームをしただけなので精度は低いです。
2回目以降も同様にランダムに根から葉へのパスを選びゲームを進めていきます。回数を重ねるごとに情報が溜まってきて、各ノードからどのノードに行くのが良さそうか分かってきます。

 これだけだと精度を上げるのに時間がかかるため、子ノードの選択を工夫します。
  1. 勝率の高そうな子ノードを重点的に選択する
  2. 訪問数の少ない子ノードも積極的に開拓する
1.と2.はトレードオフの関係にあるため、どちらを優先するかはパラメータの設定で決めます。

例題
まずは簡単な例題で実装してみました。
「最高3つカウントダウン出来て、1を言った人が勝ち」というゲームを考えます。
以下にプレイヤーAとBの対戦例を示します。

初期値: 10
A: 10, 9
B: 8
A: 7,6,5
B: 4, 3
A: 2, 1

でAの勝ちです。このゲームの最善手は動的計画法を使えば分かりますが(もしくは法則を知っていれば”4k+1を取れば勝ち”で終わり)、敢えてMCTSで解いてみます。

ソースコード
最近Pythonが多いので、久しぶりにJavaで書いてみました。 実行するとそれっぽい結果が得られました。ある状態で自分の手番が回ってきたときの勝率と最善手を計算できます。 i=19まではそれっぽいですが、i=20以降はダメです(あれ・・、全探索した方が速いような・・)。ロールアウト回数を増やすと、計算時間は増えますが大きなiでもそれっぽい結果になるはずです。
package com.kenjih.mcts;

import java.util.HashMap;
import java.util.Map;
import java.util.Stack;

public class MCTS {
    
    private static final int MAX_NUM = 3;
    private int rollOut;
    private Node root;
    
    public MCTS(int initState, int rollOut) {
        this.rollOut = rollOut;
        this.root = new Node(initState);
    }

    public int getNextBestHand() {
        
        for (int _ = 0; _ < rollOut; _++) {
            Node node = root;
            
            Stack<Node> stack = new Stack<Node>();
            while (!node.isLeaf()) {
                stack.add(node);
                node = node.expand().select();
            }
            stack.add(node);
            
            int win = 1;   // 1:win, 0:lose
            while (!stack.empty()) {
                node = stack.pop();
                ++node.n;
                node.w += win;
                win ^= 1;
            }
        }
        
        int ret = -1;
        double bestRate = getBestWinRate();
        
        for (int i : root.children.keySet()) {
            double rate = root.children.get(i).getWinRate();
            if (rate == bestRate) {
                ret = i;
                break;
            }
        }
        
        return ret;
    }
    
    public double getBestWinRate() {
        double ret = -1.0;
        
        for (int i : root.children.keySet()) {
            double rate = root.children.get(i).getWinRate();
            ret = Math.max(ret, rate);
        }
        
        return ret;        
    }
    
    class Node {
        int w;      // # of wins
        int n;      // # of visits
        int state;  // game state (current number in this case)
        Map<Integer, Node> children = null;  // hand -> next state
        
        Node(int state) {
            this.state = state;
            this.w = 0;
            this.n = 1;    
        }
        
        boolean isLeaf() {
            return state == 0;
        }
        
        Node expand() {
            if (children == null) {
                children = new HashMap<Integer, MCTS.Node>();
                
                for (int i = 1; i <= MAX_NUM; i++) {
                    if (state - i >= 0) {
                        children.put(i, new Node(state - i));
                    }
                }
            }
            
            return this;
        }
                
        Node select() {
            double bestScore = -1.0;
            Node ret = null;
            
            for (int hand : children.keySet()) {
                Node nxt = children.get(hand);
                double score = nxt.getScore(n);
                if (score > bestScore) {
                    bestScore = score;
                    ret = nxt;
                }
            }
            return ret;
        }
        
        public double getWinRate() {
            return 1.*w/n;
        }
        
        private double getScore(int total) {
            double c = Math.sqrt(2);
            double t = Math.log(total);
            return 1.*w/n + c*Math.sqrt(t / n);
        }
        
    }
    
    public static void main(String[] args) {
        
        for (int i = 1; i < 50; i++) {
            MCTS mcts = new MCTS(i, 10000);
            int ret = mcts.getNextBestHand();
            double rate = mcts.getBestWinRate();
            System.out.println(i + "->" + ret + ": " + rate);    
        }        
                    
    }

}

0 件のコメント:

コメントを投稿