二点を通る直線の式の先にある技術 第3回

リンクは記事が公開され次第随時更新します。

第3回では線形補間を取り扱います。点の数が増えても対応できるようになります。第2回で二点を通る直線をプログラムで求められるようになりました。二点を通る直線の式が求まれば、その間については予測することができます。線形補間を用いれば、二点が三点、四点…と増えていってもその間を予測することができるようになります。考え方もシンプルで、二点を通る直線の式が理解できていれば簡単です。

pythonで線形補間をとにかく使いたい!という方はinterp1dを使えばよいです。

なぜ線形補間をするか

あなたは竹の高さを測る短期のバイトをしていると想像してください。バイトの内容は一日に一度、あるタケノコの高さを測定し、記録することです。10日間のバイトで10回測定を行います。横軸が日数、縦軸が高さだとして真面目に測定すると以下のようなグラフになります。(値は適当です。)

横軸が日数、縦軸が竹の高さのグラフ

雨の日にはあまり育たなかったり、測定機器の誤差があったりして成長度合いはゆらつきますが、おおむね右肩上がりで成長しています。そんなに急激に竹の高さが変わるわけはないと考え、あなたはサボって測定を5回しか行わないことにしました。奇数の日だけ測って偶数の日はずる休みします。測定していない部分については、前の日次の日の間になるようにして提出しました。

線形補間にて偶数日の高さを予測

上の例は竹の高さを線形補間しています。こんな風に線形補間ができると、竹の高さを図る短期バイトでずる休みができるようになります。ずる休みじゃなくても、測定する機器の事情で取得できなかったデータについて線形補間を用いることがあります。

ただ、上の例では10日目のデータがとれていません。9日目まですくすく育っていた竹が10日目にイノシシに倒されているかもしれません。線形補間するとき、データが取れている範囲の値を予測することを内挿、それ以外を外挿といいます。内挿はそれなりにもっともらしい値になりますが、外挿はそうとも限りません。

線形補間と二点を通る直線の式の違いは点の数ぐらいです。二点に限らず三点以上点があるときにどうすればよいのか、次の章から説明し、今回もpythonで実装していきます。

点が三点以上あるとき

二点を通る直線を求めるプログラムは第2回で作りました。今回は点が3点以上になった場合について考えます。以下の画像のような例です。

3点のグラフ

3点の場合、2点を通る直線を引こうとすると、候補が3つあります。

直線の候補

この3点について線形補間をすると、以下のようになります。

線形補間結果

一般に、点がn個あるときに、nから2つ選ぶ場合の数で、nC2=n(n-1)/2通りの直線の引き方がありますが、線形補間ではすべての点をxの順番に並べて隣り合う点どうしに直線を引きます。上の図で緑の線は隣り合っていないために、直線はひかれません。竹の例を考えると、xは日数だったので、近い日のデータどうしを直線で結んでいます。一番近い観測から予測するので、直感ともあっていて合理的に思えます。

外挿について

線形補間というとほとんどの場合、内挿(xのある範囲)のことだけを考えますが、実運用上はその外側にあるデータについても考える必要があります。典型的な外挿の方法は以下の3つ考えられます。

  • 近くの傾きを使う
  • 固定値
  • エラー

近くの傾きを使うというのが一番直感的かと思います。上の図でうっすらと点線で描かれているようなイメージです。しかし、近くの傾きを使うと、最初や最後のデータの誤差がちょっと大きいときに予測があらぬ方向にいてしまいます。

外挿がうまくいかない例

固定値を使えばそんなに変なことにはなりませんが、予測としては微妙です。外挿はあまり信用せず、補間できなかった旨のエラーを出すなどしたほうが良いかもしれません。interp1dではfill_value引数で調整できるみたいです。

同じxについて複数のデータがある場合

このような時、単純に線形補完してしまうと、傾きが無限大になってしまう場所ができてしまいます。そのようなデータが端っこに来た場合、前述の外挿もおかしな結果になってしまいます。

同じxに複数のデータがある例

竹の例でいうと、同じ日に計測したら高さが違ってしまったという感じです。これが計測の際の誤差だとすると、その点についての値は平均値をとるのが妥当かと思います。

平均をとって線形補間

本来であれば、そのようにデータが多い点や、xが密集している点については、信頼度が高くなるように取り扱いたいですが、単純な線形補間ではそのようなことはできません。

pythonプログラム

コード量が多くなってしまったので、あまり解説はしませんが、ほとんどは描画部分です。線形補間部分はcalc_slope_intersept関数を複数回呼ぶことで実装しています。前処理として点をソートしたり、複数のデータを平均化したりしています。質問等ございましたらコメントを頂けると幸いです。

import numpy as np
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors

def plot_line(point_list, slope_list, intersept_list):
    COLORS = [
        mcolors.to_rgb("tab:blue"),
        mcolors.to_rgb("tab:orange"),
        mcolors.to_rgb("tab:green"),
        mcolors.to_rgb("tab:red"),
        mcolors.to_rgb("tab:purple"),
        mcolors.to_rgb("tab:brown"),
        mcolors.to_rgb("tab:pink"),
        mcolors.to_rgb("tab:gray"),
        mcolors.to_rgb("tab:olive"),
        mcolors.to_rgb("tab:cyan"),
    ]
    r = [-1, 11]
    rr = r + [-1, +1]

    plt.figure(figsize=(5, 5), dpi=100)

    for p in point_list:
        plt.scatter(p[0], p[1], c='darkgray')

    for i, a, b in zip(range(len(point_list) - 1), slope_list, intersept_list):
        major_color = COLORS[i]
        minor_color = COLORS[i] + (0.3,)
        if a == float('inf'):
            x1 = point_list[i][0]
            x2 = point_list[(i + 1) % len(point_list)][0]
            y1 = point_list[i + 0][1]
            y2 = point_list[(i + 1) % len(point_list)][1]
            plt.plot([x1, x2], [y1, y2], c=major_color)
            if i == 0:
                plt.plot([x1, x2], [rr[0], y1], c=minor_color,
                         linestyle='dashed')
            if i == len(point_list) - 2:
                plt.plot([x1, x2], [y2, rr[1]], c=minor_color,
                         linestyle='dashed')
        else:
            x1 = rr[0]
            x2 = point_list[i + 0][0]
            x3 = point_list[(i + 1) % len(point_list)][0]
            x4 = rr[1]
            y1 = a * x1 + b
            y2 = a * x2 + b
            y3 = a * x3 + b
            y4 = a * x4 + b
            plt.plot([x2, x3], [y2, y3], c=major_color)
            if i == 0:
                plt.plot([x1, x2], [y1, y2], c=minor_color,
                         linestyle='dashed')
            if i == len(point_list) - 2:
                plt.plot([x3, x4], [y3, y4], c=minor_color,
                         linestyle='dashed')

    plt.xlim([r[0], r[1]])
    plt.ylim([r[0], r[1]])
    plt.xticks([i for i in range(rr[0], rr[1] + 1)])
    plt.yticks([i for i in range(rr[0], rr[1] + 1)])
    plt.grid()
    plt.axhline(y=0, color='k')
    plt.axvline(x=0, color='k')
    plt.show()

def calc_slope_intersept(x1, y1, x2, y2):
    if x1 - x2 == 0:
        return (float('inf'), None)
    a = (y1 - y2) / (x1 - x2)
    b = y1 - a * x1
    return (a, b)

def average_same_point(point_list):
    d = {}
    for p in point_list:
        if p[0] in d:
            d[p[0]].append(p[1])
        else:
            d[p[0]] = [p[1]]

    new_list = []
    for (k, v) in d.items():
        new_list += [[k, sum(v) / len(v)]]

    return new_list

if __name__ == "__main__":
    x = np.random.rand(10) * 10
    y = x * 0.6 + 0.8 + np.random.normal(0, 0.15, size=x.size)
    point_list = [[xx, yy] for xx, yy in zip(x, y)]
    slope_list = []
    intersept_list = []
    point_list.sort()
    point_list = average_same_point(point_list)

    for i in range(len(point_list) - 1):
        p1 = point_list[i]
        p2 = point_list[(i + 1) % len(point_list)]
        a, b = calc_slope_intersept(p1[0], p1[1], p2[0], p2[1])
        slope_list += [a]
        intersept_list += [b]

    plot_line(point_list, slope_list, intersept_list)

最後に

今回は、線形回帰を取り扱いました。一次元の線形回帰(xが与えられたときのy)を求めるような時は単純ですが、二次元以上(xとyが与えられたときのzなど)になると、少し複雑になります。観測が格子状に与えられている場合とそうでない場合に補間の仕方が異なります。格子状であれば一次元と同じような考え方ですが、格子状でない場合はかなりややこしいです。前処理としてドロネー三角形分割をする必要があります。ドロネー三角形分割は計算コストが高く、6次元くらいになると現実的な計算時間ではなくなるそうです。これらについてはmatlabの説明がわかりやすいかもしれません。(pythonで対応するような説明を知りません。)

次回は線形回帰を微分を使わずに説明することを目標とします。

コメント

タイトルとURLをコピーしました