NF

地方で働くプログラマ

この前のコンテストのD問題の解説を書く 第2回

この前のコンテストのD問題の解説を書く 第0回 - NF
この前のコンテストのD問題の解説を書く 第1回 - NF


・全編書き終わったらまとめて1記事にします。
・あっとこで言う灰色~茶色を対象に書きます。(ABCのABは解けるけどCが解けない人とか)
・言語はC++です(C++11機能まで使います)

目次

  • 問題の解説
  • 愚直に解いてみよう
  • 大きい数のmodを取る(割り算をする)には ←今日はここ
  • 計算量を減らす為に(1)
  • 計算量を減らす為に(2)
  • まとめ

前回の続き

入力例1がACできたので、入力例2もやってみます。前回の最後に貼ったコードは入力例1の3bit用に実装していましたが、例2の入力は23bitですが、この後の事も考えて入力制約の200000桁入るよう、2か所のbitset<3>をbitset<200000>に変更してから試してください。
ちゃんと出力例2のとおり結果が出ると思います。

大きい数の割り算をする

繰り返し自乗法

さて、まだ終わりじゃないです。このまま提出すると、サンプルの2ケース以外はほとんどWAとかREになると思います。原因の1つとして、現状のf(n)の実装はunsigned long(32bit)で計算(割り算)をしていますが、Nの制約が最大200000なので、変数のサイズが全然足りません(拡張で128bitとかも使えますが全然足りません)。2進数で200000桁の値を表すには、当たり前ですが200000bitが必要になります。

大きい数の割り算をどうやるか?今回は、「繰り返し二乗法」という方法を使います。「繰り返し二乗法」の解説は検索すると沢山出てきますが、簡単に言うと2の8乗が
2^8 = 2^4 × 2^4
2^4 = 2^2 × 2^2
2^2 = 2^1 × 2^1
のように分解すると計算回数を減らせるというものです。競プロでは計算回数を減らして高速化するために用いられる事が多いらしいですが、大きい数の計算にも応用できます。
参考:繰り返し二乗法で n^p % mod を拘束に求める

なお、ここでは「繰り返し二乗法」自体については詳しく掘り下げないです。これを使えば計算できるよ!って事だけを書きます。いきなり色々覚えようとするとツライので…


nのp乗をmで割り算して余りを返す関数pow()を実装します。こんな感じなります。
例えば2の3乗を3で割ると8%3で2になると思います。pow(2, 3, 3)を実行するとちゃんと2が返されます。普通では計算ができない2の200000乗割る2とかも計算できます。

long long pow(long long n, long long p, long long m)
{
    long long ret = 1;
    for (; p > 0; p >>= 1, n = n * n % m){
		if (p%2 == 1){
			ret = ret * n % m;
		}
	}
	return ret;
}
合同式

さて、これで「2^p割るm」は計算できるようになりました。では「11」のように「2^p」で表せない数はどう計算するのか?これは「合同」という考え方が使えるようです。「合同」とは一言でいうと、「割り算の余りのみに注目した等式」だそうです。
参考:合同式の意味とよく使う6つの性質

例えば、7割る3(7mod3)は1ですが、4mod3も1です。この時、「7≡4」と書きます。記号「≡」は「ごうどう」で変換すると出てきました。同様に、8mod3は2、5mod3も2なので、「8≡5」ということになります。


合同の性質の1つに、「合同式は辺々足し算できる」というのがあります。先ほどの例で言うと、mod3を計算する場合において「7≡4」、「8≡5」となるので、右辺と左辺を足して「15≡9」となります。本当か?という方は計算してみると分かりますが、15mod3は0、9mod3も0です。

この性質を利用すると、「2^p」で表せない数が計算できます。例えば「11」は8+2+1であり、2^3+2^1+2^0に分解できますが、合同の性質を使うと「11mod3」は「2^3 mod3+2^1 mod3+2^0 mod3」に言い換えらえます。つまり「2^p」の和の形に変換できるので、各項をそれぞれ先ほど実装したpow()を使ってから足せばいいことになります。実際にやると、8mod3=2、2mod3=2、1mod3=1なので、2+2+1=5です。5はmod値の3より大きいので、5に対して再度同じ計算をします。2進数で101なので、4mod3+1mod3=2となり、これが11mod3の答えになります。


N桁の2進数に対してmod取る関数divpow()を実装してみます。引数には、bitsetで表現した2進数と桁数N、mod値を渡して、割り算した余りを返すようにします。

template<size_t SIZE>
long long divpow(std::bitset<SIZE> bs, long long N, long long m)
{   
    long long ret;
    do{
        ret = 0;
        for(long long i=N-1; i>=0; i--){
            if(bs[i] == 1){  // ビットが立っている場合はpowを計算
                ret += pow(2, i, m);
            }
        }
        std::bitset<SIZE> tmpbs(ret);
        bs = tmpbs;
    }while(ret >= m);  // 余りがm未満になるまで繰り返す
    
    return ret;
}

C++のstd::bitsetのSIZEは定数でないとならないのでテンプレートにしましたが、特に気にしなくて大丈夫です。「11」(2進数で1011)を以下のように計算させると、「2」が出力されます。

    std::bitset<200000> bs("1011");
    std::cout << divpow(bs, 4, 3) << std::endl;
f(n)をdivpowを使うよう修正

f(n)をdivpowを使って実装し直します。
divpow()も少し変更して、powの結果を余りを取りなら計算します。これも合同の性質ですね。
1回あたりのpowの結果はmよりは小さくなるので、この計算はintでできます。

template<size_t SIZE>
long long divpow(std::bitset<SIZE> bs, long long N, long long m)
{   
    long long ret = -1;
    do{
        ret = 0;
        for(long long i=N-1; i>=0; i--){
            if(bs[i] == 1){  // ビットが立っている場合はpowを計算
                ret += pow(2, i, m);
                if(ret <= m){
                    ret = ret % m;   // 計算しなが余りを取る
                }
            }
        }
        std::bitset<SIZE> tmpbs(ret);
        bs = tmpbs;
    }while(ret >= m);  // 余りがm未満になるまで繰り返す
    
    return ret;
}

template<size_t SIZE>
int popcount(std::bitset<SIZE> bs)
{
    int count = bs.count();
    return count;
}
    
template<size_t SIZE>
unsigned long f(std::bitset<SIZE> bs, long long N)
{
    long long ret = -1;
    long long count = 0;
    do{
        // nをpopcount(n)で割ったあまりに置き換える
        // n = n % popcount(n);
        int hibit = popcount(bs);
        ret = divpow(bs, N, hibit);
                     
        // 操作回数をカウント
        count++;
        
        // divpowの戻り値を次の計算対象に設定
        std::bitset<SIZE> tmpbs(ret);
        bs = tmpbs;
    }while(ret != 0);
    
    // nが0になるまでの操作回数を返す
    return count;
}

今回のまとめ

ここまでの内容で問題が解けそうですが、実はまだ無理です。試しに提出してみると、TLE(時間切れ)になると思います。
例えば制約上一番大きい20000桁が与えられた場合、これまでの実装だとf(n)を200000回計算することになりますが、制約時間の「2sec」では終わりません。どれくらい足りないの?って言われるとちゃんと答えられないですが、少なくとも今回の実装例だとpowがO(logN)(らしい)、divpowがN回ループなのでO(N)、f(n)がO(N)なので、全体でO(N×N×logN)となります。制約が2×10^5なので10^9を余裕で超えてしまいそうです。

という訳で、次回は計算量を減らす工夫について書きます。


ーーー
解説(のようなもの)、書きなれてないので長くなってしまいますね…。あと図とかないと厳しいかも。

あとぐぐると繰り返し「二乗法」と「自乗法」」が引っかかって前者の方が多いんだけど、やってる事的に「自乗法」の方が正しいのでは…