NF

地方で働くプログラマ

この前のコンテストのD問題の解説を書く 第3回(最終回)

この前のコンテストのD問題の解説を書く 第0回 - NF
この前のコンテストのD問題の解説を書く 第1回 - NF
この前のコンテストのD問題の解説を書く 第2回 - NF


・全編書き終わったらまとめて1記事にします。
・あっとこで言う灰色~茶色を対象に書きます。(ABCのABは解けるけどCが解けない人とか)
・言語はC++です(C++11機能まで使います)

目次

  • 問題の解説
  • 愚直に解いてみよう
  • 大きい数のmodを取る(割り算をする)には
  • 計算量を減らす為に(1) ←今日はここ
  • 計算量を減らす為に(2)
  • まとめ  ←今日はここ

前回の続き

めちゃくちゃ日が空いてしまい反省。TLEが取れずそのまま忘れてたのが原因です。

前回は大きい数字(10^18以上)の計算方法を調べながらサンプル通る事を確認しましたが、入力によってはTLEしてしまいました。今回は、1.初回計算を高速化する(解説記載の内容)、2.計算結果を使いまわす、3.0除算に注意する、の3本立てで記載します。目次ではもう一回予定してましたが、一回でまとまりそうなので最終回です。

1.初回計算を高速化する

公式の解説PDFに書かれてる内容です。(逆に、PDFはいきなりここから始まるので、自分は前回・前々回に書いた内容が分からず困ってました)

高速化のポイントは、各整数iについての計算内容はそのほとんどが重複していて、毎回計算しなくていいという事です。各iに対する実際の計算内容を並べてみると、この点に気付けます。


51(2進数で110011)を例に、初回の計算を説明します。入力は「6 110011」のパターンです。
i桁目を反転してf(n)を計算するので、以下の計算をすることになります。

010011 mod3 =            2^4 mod3                       + 2^1 mod3 + 2^0 mod3 
100011 mod3 = 2^5 mod3                                  + 2^1 mod3 + 2^0 mod3 
111011 mod5 = 2^5 mod5 + 2^4 mod5 + 2^3 mod5            + 2^1 mod5 + 2^0 mod5 
110111 mod5 = 2^5 mod5 + 2^4 mod5            + 2^2 mod5 + 2^1 mod5 + 2^0 mod5 
110001 mod3 = 2^5 mod3 + 2^4 mod3                                  + 2^0 mod3 
110010 mod3 = 2^5 mod3 + 2^4 mod3                       + 2^0 mod3

という訳で1bitしか変わらないので、元の値が4bitの場合必ずmod5かmod3の計算になります。更に、反転する桁の2^i modXを足し引きするだけで初回の計算ができます。これで、元の「110011」に対するmod5とmod3を事前に計算(これはO(NlogNぽいです))しておけば、桁数N×桁数Nの計算が桁数Nの計算だけでいけます。
また、2^i modXも最初に全桁計算して配列に入れておけば、以降使いまわせます。


例えば110011 mod5は1なので、上から3bitを反転した「111011」のmod5は以下のように計算します。

// 110011 mod5 = 2^5 mod5 + 2^4 mod5 +                     + 2^1 mod5 + 2^0 mod5 
// 111011 mod5 = 2^5 mod5 + 2^4 mod5 + 2^3 mod5            + 2^1 mod5 + 2^0 mod5 

110011 mod5 = 1
111011 mod5 = (110011 mod5) + 2^3 mod5 = 1 + 3 = 4


コードに落とします。型とかが適当になってるのは許してください。
ちなみに、1回目の計算でnが十分小さくなっている(≦200000)ので、2回目以降の計算(while文の部分)ではdivpowを使わなくてよいです。折角実装したけど。

std::map<int, long long> mp1;  // 各桁の2^i mod(hibit+1) の結果
std::map<int, long long> mp2;  // 各桁の2^i mod(hibit-1) の結果
unsigned long ans1 = 0;  // X mod (hibit+1) に対する初回計算結果
unsigned long ans2 = 0;  // X mod (hibit-1) に対する初回計算結果

unsigned long f(std::bitset<200000>& bs, long long index)
{
    long long n = -1;
    long long count = 0;
    
    // 初回だけmpを使って計算
    int hibit = bs.count();
    if(bs[index] == 0){
        hibit++;
        n = ans1 + mp1[index];
    }else{
        hibit--;
        n = ans2 - mp2[index];
    }
    if(n >= hibit){
        n = n % hibit;
    }

    // f(n)の2回目以降の計算
    while(ret != 0){
        std::bitset<200000> tmpbs(n);
        n = n % tmpbs.count();
        count++;
    }
}

int main()
{
    std::bitset<200000> bs(X);
    const int count = bs.count();
    
    // 前処理
    for(long long i=N-1; i>=0; i--){  // 制約から1<=Nなので、count=0のケースは考慮不要
        mp1[i] = pow(2, i, count+1);
        mp2[i] = pow(2, i, count-1);
        if(bs[i] == 1){
            ans1 += mp1[i];
            ans2 += mp2[i];
        }
    }

    // 本処理
    for(long long i=N-1; i>=0; i--){
        unsigned long result = f(bs, i);
        …


計算量を減らせましたが、まだ一部でTLEが残ります。公式解説通りに実装したのにどうして…
これはちゃんと原因分かってないので、自分への今後の課題とします。他の人のACコードをもう少し眺めてみようかと思います。

結果から言うと思いつく範囲で計算回数を減らしたところ、なんとかACすることができました。そこまで効くと思わずに試してみたので、ちょっと勉強になったので書きます。では次の項へ。

2.計算結果を使いまわす

2つの改善をしました。

1つが、関数f()の2回目以降の計算(while文の部分)で、nに対する計算結果を保存して、次以降同じnには使いまわすようにしました。これはmapの検索が入るので逆に遅くなるのでは?と思いましたが、試すと幾つかのTLEが回避できるようになりました。

    // 計算量削減のため初回計算後のnとそれに対する計算回数を記憶しておく
    std::map<int, long long>::iterator itr = results.find(n);
    if(itr != results.end()){
        return itr->second;
    }
    int tmp = n;
    while(n != 0){
        std::bitset<200000> tmpbs(n);
        n = n % tmpbs.count();
        count++;
    }
    results.emplace(tmp, count);

2つ目は、初回計算に使う入力X中の「1」の個数(=popcount(X))の計算を使いまわすようにしました。具体的には、main()とf()の中でそれぞれ計算していたので、main()で1回だけ計算してf()には引数で渡すようにしました。これで最大200000回分bitset::countを省略できますが、これが結構効果が高かったです。もちろんf()で毎回計算するのがムダなのは分かっていたのですが、計算コストは大したことないと思い込んでいました。。

ちなみに、C++のリファレンス等見てもbitset::countの計算量は明記されておらず。(宿題)

unsigned long f(std::bitset<200000>& bs, int hibit, long long index)
{
    …

int main()
{
    const int count = bs.count();
    …
    // 本処理
    for(long long i=N-1; i>=0; i--){
        unsigned long result = f(bs, count, i,);
        …


これで全てのTLEは取れたと思います。
しかし、当初の自分の実装だとRE(実行時エラー)が幾つかのテストケースで発生しました。これは実装を注意してしてれば防げたので本質的な話ではないですが、補足として説明しておきます。

3.0除算に注意する

テストケースの「RE」について色々考えたところ、0除算が発生するケースに気付きました。分かり易いのだと「100000」はhibitが1つですが、1.の解説の通り計算する前処理でmod2とmod0を計算することになり、後者の場合0除算が発生します。

今回の場合はmod0のケースは単に0を返せばいいので、気づけば対策はできました。
ソースコードに「X / Y」とか「X % Y」がある場合は要注意です。


という訳で、完成版のソースコードは如何になります。これでACできました。
書いてませんでしたが、powもp=0の時に動かない問題があったので、別途処理を入れてます。

#include <iostream>
#include <string>
#include <bitset>
#include <map>
using namespace std;
 
std::map<int, long long> mp1;
std::map<int, long long> mp2;
unsigned long ans1 = 0;
unsigned long ans2 = 0;
 
long long pow(long long n, long long p, long long m)
{
    if(m == 0) return 0;
    if(p == 0) return (1 % m);
    long long ret = 1;
    for (; p > 0; p >>= 1, n = n * n % m){
		if (p%2 == 1){
			ret = ret * n % m;
		}
	}
	return ret;
}
 
std::map<int, long long> results;
 
unsigned long f(std::bitset<200000>& bs, int hibit, long long index)
{
    long long n = -1;
    long long count = 0;
    
    // 初回だけmpを使って計算
    //int hibit = bs.count();
    if(bs[index] == 0){
        hibit++;
        n = ans1 + mp1[index];
    }else{
        hibit--;
        if(hibit <= 0) return count;
        n = ans2 - mp2[index];
        if(n < 0) n += hibit;
    }
    if(n >= hibit){
        n = n % hibit;
    }
    count++;
 
    // 計算量削減のため初回計算後のnとそれに対する計算回数を記憶しておく
    std::map<int, long long>::iterator itr = results.find(n);
    if(itr != results.end()){
        return itr->second;
    }
    int tmp = n;
    while(n != 0){
        std::bitset<200000> tmpbs(n);
        n = n % tmpbs.count();
        count++;
    }
    results.emplace(tmp, count);
 
    // nが0になるまでの操作回数を返す
    return count;
}
 
int main()
{
    long long N;
    std::cin >> N;
    std::string X;
    std::cin >> X;
 
    // 前準備
    std::bitset<200000> bs(X);
    const int count = bs.count();
 
    // 制約から1<=Nなので、count=0のケースは考慮不要
    for(long long i=N-1; i>=0; i--){
        mp1[i] = pow(2, i, count+1);
        mp2[i] = pow(2, i, count-1);
        if(bs[i] == 1){
            ans1 += mp1[i];
            ans2 += mp2[i];
        }
    }
 
    // ここから本処理
    for(long long i=N-1; i>=0; i--){
        unsigned long result = f(bs, count, i);
        std::cout << result << std::endl;
    }
}

まとめ

足掛け1カ月になってしまいました。想定読者は自分だけなのでいいのだけど。
最初に書いたとおり本問題は色々知らないアルゴリズムが知れてよかったです。解説書く事でちゃんと理解できるのもよかったので、またたまにやりたいです。本当はQiitaとかに書いた方が情報が集約される(ここ誰も見ないので…)んですが、なんか高レートじゃない人が解説とか…って言う人もいるらしい(実際に観測した)ので、ここでまったりやるのがいいのかも。


もし読んでくれた方がいたら、ここ分かり辛いとか感想とかご指摘いただけると嬉しいです。

それでは。