朋友偶然 pass 一篇 Deep Neural Decision Trees[2018] 的 paper 過來,裡面用了一個可微分的方式去表示一個 decision tree 所以就可以用 Deep Learning 的手法去訓練一個 decision tree。看 paper 的過程中,看它的實作碰到需要 Kronecker product,用了一個之前看過但其實不熟悉的 torch.einsum
的 api。如果有念過基本的 Tensor Calculus 對於 Einstein Summation 應該不會太陌生,但我發現自己好像從來沒坐下來好好把這方面的思路想清楚過。由於 torch.einsum
其實在實作各種 tensor 運算時真的很好用,所以值得我好好寫一篇筆記記錄一下,那廢話就到此為止,來看看 Einstein Summation 在說什麼吧。
Einstein Summation Notation 是什麼
根據 wiki 上的定義,Einstein Notation 是這樣的
… is a notational convention that implies summation over a set of indexed terms in a formula, thus achieving notational brevity.
也就是說,Einstein Notation 就是個簡化方程式中的加總項,尤其是帶有 index 的加總。
或許你會想“不過就是個符號嘛,有什麼好特別的?”,其實我自己在學習一些偏理論的東西的時候,有發現其實一個好的符號系統會幫助你推導跟思考的容易程度,進而改進學習效果,另外一個題外話就是,有心理學研究指出一個人的色感跟其所使用的語言有關 (我不是心理學專業,也理解這個是還在辯論中的題目,但確實是有研究支持這個說法,譬如這篇),所以說其實怎麼表示一個問題會大大影響你的思考方式。我自己是認為 Einstein Notation 可以有下列好處:
- 知道每個做為參數的 tensor 的每一個維度應該要怎麼對齊,也就是說第幾個維度的 element 應該跟另一個維度對在一起做加總
- 每個 tensor 的維度是多少
- 產出的 tensor 維度是多少
我自己是覺得 3 在對實作 deep learning model 時尤其有用,因為我們常常就是 print 上個 layer 產生的 tensor 的維度去看看到底跟接下來的 layer 的 input 所需的維度有沒有對上,那寫成 Einstein Notation 後這些都很明顯。
簡單的來說,Einstein Notation 是由 index ( i
、j
、k
等等)、逗號 (,
) 跟箭頭 (->
)組成,這邊我們來看一些例子。
- Inner Product:
i,i->
- Outer Product:
i,k->ik
- Transpose:
ik->ki
- Matrix Multiplication:
ik,kj->ij
所以簡單的來說,有以下規則:
- 在
->
左手邊表示的是做為 input 的 tensor 有幾個,以,
隔開,每個,
間的 index 個數代表相對應的 tensor 維度是多少,譬如說ik
代表說對應位置的 tensor 應該要是一個 rank 為 2 的 tensor 。 - 再來就是 index 的位置,以
ik,kj->ij
為例,因為第一個 tensor 的第二個維度被標記為k
而第二個 tensor 則是第一個維度標記為k
,代表該位置的元素會被乘在一起。 ->
左邊有但右邊沒有的 index 會被加總。例如上面矩陣乘法的例子,在->
左手邊有k
但右手邊沒有,代表沿著k
維度的元素乘起來的結果會被加總。
以目前有的例子來說,或許會覺得 Einstein Notation 並沒有太大的用途,這邊我們考慮下面這個例子:
也就是說我在 A
的 row 跟 B
的 column 做 element-wise 相乘後,不是加總而是取最大值,善用 np.einsum
之後可以寫成:
稍微驗算一下:
我目前除了 for-loop 外沒有想到不用 einsum
的實作要怎麼做 (歡迎回覆你的做法來打我臉 XD),相比之下不難看出 einsum
的實作是更簡潔好讀的,當然前提是要看得懂 einsum
在說什麼就是了。
在網路上隨意搜搜之後,發現這篇文章有提供一個列表,有很多 einsum
的範例,可以做為練習看看根據上面列出來的規則你可不可以知道那些例子在算什麼。
另外該文章也有提到一個我之前沒注意到的重點,相較於一些高階 api 會去做 type/shape promotion, einsum
並不會這麼做,所以有機會引入一些 overflow/underflow 的風險,譬如說一個 array 是 int8
,用 einsum
加總後可能就會 overflow 。另外也可以省略 ->
跟右手邊的 output indices,但老實說看了一些範例之後,我覺得這種寫法除了不是很 explicit 之外,也有可能忽略 einsum
隱含的一些 index 排序規則造成預期外行為,所以我自己是不太喜歡這種寫法就是了。
好啦,關於 einsum
的筆記大概就這樣啦,之後應該會寫一些 Neural Decision Tree 的東西,因為這方面的模型實作確實是蠻有趣的。
Happy Python Programming!