世界一速いナベアツに挑む①

この記事はプロコンゼミ(SPC同好会) その2 Advent Calendar 2018 - Adventarの14日目の記事です。

導入

このネタは、我が部活が誇るあかこう先輩のネタのパクリ継承となります。 元の記事は削除されてしまったようなので、参考になりそうなリンクを貼っておきます。

https://akakou-slide.github.io/nabeatsu/ github.com

ところで、ナベアツというのは皆さんご存知の世界のナベアツこと桂三度さんのネタのことです。

わからない人のために...

例えばAくんという人が居るとします。Aくんは自然数を数え上げていきます。 数え上げる自然数が3の倍数、もしくは3がつく数のときにおかしな声をだします(このことを、バカになるという)。

注意

  • この記事の内容を実際に試したことにより、システムにどのような損害が起きたとしても筆者は責任を負いかねます。試す際には自己責任で宜しくおねがいします。
  • 世界一速いナベアツを名乗っていますが、これは少し誇張した表現ですので訂正します。この記事では、x86_64上で動作するLinux 4.19において最も高速に動作するナベアツを目指します。

高速化のポイント

ソースコードはMITライセンスということで、ありがたく高速化のポイントの解説として、使わせていただきましょう。

MIT License

Copyright (c) 2017 akakou

Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

一応ライセンス表示。

ということで、一旦あかこう先輩が書いたソースコードを全部のせます。

; nabeatsu.s
; 最強のNABEATSU☆を目指す


; コードセクション
section .text
    global _start                           ; _startを指名

; スタート
_start:
                                            ; カウント用レジスタの初期化
    mov r12, 1                              ; ボケるか判定用カウントレジスタ(普通にカウントする)
    mov r13, 1                              ; 表示用カウントレジスタ(BCDでカウントする)

; 3の倍数かをチェックする
CHECK_THREE_MULTIPLE:
    xor rdx, rdx
                                            ; r12を3で割る
    mov rax, r12                            ; r12/3(rbx) = rax...rdx
    mov rbx, 3
    div rbx

    cmp rdx, 0                              ; 除算の余りと0を比較する
    je SET_BOKE_FMT                         ; 0ならprint_abnomalにジャンプする

    mov r9,  r12                            ; r12(ボケるか判定用カウントレジスタ)をr9(3を含む数値か判定用レジスタ)にコピー

READY_CHECK_IN_THREE:
    mov r8, r13

; 3のつく数字化をチェックする
CHECK_IN_THREE:
    mov rbx, 0xF                            ; rbxに0xFを格納
    and rbx, r8                             ; r8と0xFに論理和をかけることで、16進数での1桁分取り出し、rbxに格納

    shr r8,  4                              ; 16進数におけるBCDでの1桁分右にシフトする。

    cmp rbx, 3                              ; シフトの余りと3を比較する
    je SET_BOKE_FMT                         ; 余りが0ならprint_abnomalにジャンプする

    cmp r8, 0                               ; 除算のあまりと0を比較し
    jne CHECK_IN_THREE                      ; 余りが0ならprint_abnomalにジャンプする

; 普通のフォーマット(数値のみを表示)をrbpに格納する。
SET_NOMAL_FORMAT:
    mov rbp, normal_message                 ; rbpに数値のみ表示するフォーマットを格納する

; 次の処理の準備 & `SET_NOMAL_FORMAT`からのジャンプバック用ラベル
READY_TO_SET_FORMAT_LOOP:
    mov rax, r13                            ; rax(今回表示する値が入ってるレジスタ)にr13(表示用カウントレジスタ)の値をコピー
    mov r15, 15                             ; r15に15(最大このプログラムの耐えられる最大桁数)
    add r15, rbp                            ; rbpの値(指定したフォーマットの文字列)と最大値を足した値をr15(次に格納する文字のアドレス先)に格納

; r13から1文字ずつ取り出す
SET_FORMAT_LOOP:
    dec r15                                 ; r15を1つ減算する
    mov rbx, 0xF                            ; rbxに0xFを格納
    and rbx, rax                            ; raxと0xFに論理和をかけることで、16進数での1桁分取り出し、rbxに格納
    add rbx, 0x30                           ; rbxに格納された1桁の数値に0x30を足すことで、ASCIIの文字のデータにする
    mov [r15], bl                           ; bl(rbxから1バイト分のデータ)をr15の持つアドレス先に書き込む
    shr rax, 4                              ; 16進数におけるBCDでの1桁分右にシフトする。
    jnz SET_FORMAT_LOOP                     ; raxが0になるまでSET_FORMAT_LOOPを回し続ける

; 文字列を表示
PRINT:
    mov r14, rcx                            ; rcxの値をr14に退避

    mov rcx, rbp                            ; メッセージのアドレス
    mov rdx, 17                             ; メッセージの長さ
    mov rbx, 1                              ; 標準出力を指定
    mov rax, 4                              ; システムコール番号 (sys_write)
    int 0x80                                ; 割り込み

    mov rcx, r14                            ; r14に退避させた値をrcxに戻す

; カウントをする
COUNT:
    inc r12                                 ; r12に1加算
    inc r13                                 ; r13に1加算

    mov r9, r13                             ; r9(繰り上がれてないか確認用レジスタ)にr13(表示用カウントレジスタ)をコピー
    mov r11, 0xF                            ; r11に0xFを格納

; 表示用カウントレジスタ(BCD)の桁上がりチェック
CARRY_UP:
                                            ; 各桁に0xAが含まれていないか確認し、含まれていたら`CARRY_UP_BCD`を呼ぶ
    mov r10, 0xF                            ; r10に0xFを格納
    and r10, r9                             ; r10にr9の右から1桁分を格納させる
    cmp r10, 0xA                            ; r10(上の処理で取った1桁分)と0xA(BCDではありえない値)かを確認
    je CARRY_UP_BCD                         ; r10が0xAなら繰り上がりをする

                                            ; 二重繰り上がりが起きたときの処理
    cmp r10, 0xB                            ; r10に0xBが格納されていないかを確認
    je CARRY_UP_BCD                         ; r10が0xBなら繰り上がりをする(0xC以上は理論上ありえない)

; 次のカウント後の処理へジャンプする
LOOP_BACK:
    cmp r12, 1000000                        ; 最大カウント回数(100000)とr12(通常カウント)を比較
    jne CHECK_THREE_MULTIPLE                ; 同じになるまでCHECK_THREE_MULTIPLEに戻る

; 終了処理
FIN:
                                            ; プロセス終了
    mov rax, 1                              ; 返り値を1にする
    mov rbx, 0                              ; 終了ステータスコード
    int 0x80                                ; システムコール

; 表示用カウントレジスタ(BCD)の桁上がり処理
CARRY_UP_BCD:
    or r13, r11                             ; うまく繰り上がれてない値(r13)の、繰り上がれてない部分をすべて二進数の1で埋めて
    inc r13                                 ; 1を足す → 強制的に繰り上がらせる

    or r9, 0xF                              ; r9(繰り上がれてないか確認用レジスタ)の方も
    inc r9                                  ; 繰り上がりしておく(ここで二重繰り上がりの可能性あり)

    shr r9, 4                               ; r9の値を右に1文字分(16進数ひと桁分)シフトする
    shl r11, 4                              ; `CARRY_UP_BCD`で使う対象以下の2進数での桁を1で埋めたもの(マスク)を更新する
    or r11, 0xF

    jmp CARRY_UP                            ; `CARRY_UP`に戻る

; ボケるときのフォーマットをrbpに格納
SET_BOKE_FMT:
    mov rbp, boke_message                   ; rbpにボケたときのメッセージフォーマットを格納
    jmp READY_TO_SET_FORMAT_LOOP            ; `READY_TO_SET_FORMAT_LOOP`にジャンプする(戻る)


; データセクション
section  .data                              ; データセクションの定義
    boke_message  db 0xA, "(BOKE)          "    ; boke_messageの内容
    times 20 db 0x00

    normal_message db 0xA, "                "   ; normal_messageの内容
    times 20 db 0x00

では、高速化のポイントを洗い出してみましょう。

システムコール

あかこう先輩のソースコードではシステムコールの呼び出しとして割り込み命令が使われていますが、これはレガシーな方法であり最近のx86では高速システムコールという仕組みが用意されているのでこちらを用いるのが一般的です。 高速システムコールを行う場合、システムコール番号が異なるので気をつけましょう。

; 終了処理
FIN:
                                            ; プロセス終了
    mov rax, 1                              ; 返り値を1にする
    mov rbx, 0                              ; 終了ステータスコード
    int 0x80                                ; システムコール

レガシーなシステムコール呼び出し

FIN:
    mov rdx, 1
    mov rax, 60
    syscall

高速システムコール呼び出し

linux, x86_64の場合、以下のテーブルを見ることでシステムコール番号を知ることが出来ます。

github.com

ということで、printルーチンとfinルーチンのシステムコールの呼び出し方を変えてみました。まあ、finルーチンは一度しか呼ばれないのであまり効果がないですが。

; 文字列を表示
PRINT:
    mov r14, rcx                            ; rcxの値をr14に退避

    mov rcx, rbp                            ; メッセージのアドレス
    mov rdx, 17                             ; メッセージの長さ
    mov rbx, 1                              ; 標準出力を指定
    mov rax, 4                              ; システムコール番号 (sys_write)
    int 0x80                                ; 割り込み

    mov rcx, r14                            ; r14に退避させた値をrcxに戻す

旧PRINTルーチン

PRINT:
    mov r14, rcx

    mov rsi, rbp
    mov rdx, 17
    mov rax, 1
    mov rdi, 1
    syscall

    mov rcx, r14

新PRINTルーチン

FINルーチンは上で示した通り。

これらのルーチンを入れ替えて得られる効果はどの程度のものでしょうか。timeコマンドで計ってみます。 システムの状態によって実行時間が多少前後するのでそれぞれ10回計測しました。

./a.out  0.25s user 0.61s system 99% cpu 0.859 total
./a.out  0.20s user 0.67s system 99% cpu 0.867 total
./a.out  0.26s user 0.65s system 97% cpu 0.933 total
./a.out  0.28s user 0.69s system 99% cpu 0.976 total
./a.out  0.27s user 0.65s system 99% cpu 0.926 total
./a.out  0.22s user 0.70s system 99% cpu 0.922 total
./a.out  0.22s user 0.77s system 99% cpu 0.999 total
./a.out  0.25s user 0.67s system 99% cpu 0.916 total
./a.out  0.25s user 0.69s system 99% cpu 0.937 total
./a.out  0.21s user 0.64s system 99% cpu 0.849 total

これが割込みを使ったシステムコール呼び出しの場合。

./a.out  0.14s user 0.62s system 99% cpu 0.770 total
./a.out  0.19s user 0.61s system 99% cpu 0.800 total
./a.out  0.11s user 0.72s system 99% cpu 0.838 total
./a.out  0.15s user 0.70s system 99% cpu 0.856 total
./a.out  0.17s user 0.65s system 99% cpu 0.818 total
./a.out  0.12s user 0.69s system 99% cpu 0.818 total
./a.out  0.11s user 0.73s system 97% cpu 0.869 total
./a.out  0.09s user 0.75s system 99% cpu 0.844 total
./a.out  0.14s user 0.71s system 99% cpu 0.848 total
./a.out  0.12s user 0.71s system 99% cpu 0.832 total

これが高速システムコール呼び出しを使ったシステムコール呼び出しの場合。

結果は一目瞭然ですね。

一応平均値と中央値を出しておきましょう。

打ち込むのが面倒だったので適当にプログラムを書きました。timeのフォーマットなら使えると思うので使いたい方はどうぞ。

#include <iostream>
#include <vector>
#include <algorithm>

using namespace std;

int main()
{
    string tmp, user, system;
    int count = 10;
    double heikin_user = 0, heikin_sys = 0;
    vector<double> _user, _sys;
    for(int i = 0; i < count; ++i) {
        cin >> tmp >> user >> tmp >> system >> tmp >> tmp >> tmp >> tmp >> tmp;
        user[4] = '\0';
        system[4] = '\0';
        _user.push_back(stod(user));
        _sys.push_back(stod(system));
    }
    for(auto hoge : _user) {
        heikin_user += hoge; 
    }
    heikin_user /= 10.0;
    for(auto hoge : _sys) {
        heikin_sys += hoge;
    }
    heikin_sys /= 10.0;
    sort(_user.begin(), _user.end());
    sort(_sys.begin(), _sys.end());
    cout << "heikin[ user: " << heikin_user << " sys: " << heikin_sys << " ]" << endl;
    cout << "chuou [ user: " << (_user[4] + _user[5]) / 2.0 
        << " sys: " << (_sys[4] + _sys[5]) / 2.0 << " ]" << endl;
    return 0;
}
割込みを使ったシステムコール呼び出し
heikin[ user: 0.241 sys: 0.674 ]
chuou [ user: 0.25 sys: 0.67 ]

高速システムコール呼び出しを使ったシステムコール呼び出し
heikin[ user: 0.134 sys: 0.689 ]
chuou [ user: 0.13 sys: 0.705 ]

おおー、いい感じです。

除算の高速化

除算は非常にコストのかかる計算です。x86_64にはdiv命令というものがありますが、これは8bit同士の計算だけでも41サイクルもかかってしまいます。

gccはこの問題をどのようにクリアしているのでしょうか。見てみましょう。

#include <stdio.h>

int main(void)
{
    for(int i = 0; i < 100; ++i) {
        int hoge = i / 3;
        printf("%d\n", hoge);
    }
    return 0;
}
$ gcc -S hoge.c
 .file "hoge.c"
    .text
    .section  .rodata
.LC0:
    .string   "%d\n"
    .text
    .globl    main
    .type main, @function
main:
.LFB0:
    .cfi_startproc
    pushq    %rbp
    .cfi_def_cfa_offset 16
    .cfi_offset 6, -16
    movq %rsp, %rbp
    .cfi_def_cfa_register 6
    subq $16, %rsp
    movl $0, -4(%rbp)
    jmp  .L2
.L3:
    movl -4(%rbp), %ecx
    movl $1431655766, %edx
    movl %ecx, %eax
    imull    %edx
    movl %ecx, %eax
    sarl $31, %eax
    subl %eax, %edx
    movl %edx, %eax
    movl %eax, -4(%rbp)
    movl -4(%rbp), %eax
    movl %eax, %esi
    leaq .LC0(%rip), %rdi
    movl $0, %eax
    call printf@PLT
    addl $1, -4(%rbp)
.L2:
    cmpl $99, -4(%rbp)
    jle  .L3
    movl $0, %eax
    leave
    .cfi_def_cfa 7, 8
    ret
    .cfi_endproc
.LFE0:
    .size main, .-main
    .ident    "GCC: (GNU) 8.2.1 20180831"
    .section  .note.GNU-stack,"",@progbits

普段intel記法を使っているのでAT&T記法は読みづらい...

注目すべきは以下の部分です。

.L3:
    movl -4(%rbp), %ecx
    movl $1431655766, %edx
    movl %ecx, %eax
    imull    %edx
    movl %ecx, %eax
    sarl $31, %eax
    subl %eax, %edx
    movl %edx, %eax
    movl %eax, -4(%rbp)
    movl -4(%rbp), %eax
    movl %eax, %esi
    leaq .LC0(%rip), %rdi
    movl $0, %eax
    call printf@PLT
    addl $1, -4(%rbp)

見ての通りdiv命令は使われずimull命令とsarl命令とsubl命令等によって除算が実現されています。 この原理の基本は以下のとおりです。

まず、前提知識としてビット演算において2nで掛けたり割ったりするのはシフト演算で実現できるため、高速に計算できます。 また、整数同士の掛け算も割り算に比べれば低コストで計算することが出来ます。

これらの知識を元に、除算を以下のように変形します。

 \displaystyle
  \frac{a}{b} = a * \frac{2^{n}}{b} * \frac{1}{2^{n}}

 bは定数なので \frac{2^{n}}{b} コンパイル時に計算することが出来ます。

ですから、実行時に計算するのは整数同士の掛け算( \frac{2^{n}}{b}は整数に丸められます)と2nでの割り算だけなのでこれで除算の高速化が出来るわけです。

基本は上の通りですがこれでは誤差が出てしまう可能性があります。ですのでまだ工夫する必要は有るのですがこちらに関しては以下のブログを参考にしていただくとして、ここでは割愛いたします。

7shi.hateblo.jp

疲れてしまったので実装をどうするかについては「世界一速いナベアツに挑む②」で紹介しようと思います。お楽しみに!

参考文献

X86アセンブラ/算術演算命令 - Wikibooks

https://www.agner.org/optimize/instruction_tables.pdf

postd.cc

https://www.recfor.net/jeans/index.php?itemid=902

Binary Division by a Constant

除算 (デジタル) - Wikipedia