sqlparse で SQL から更新対象のカラムを抽出

sqlparse を使って SQL の UPDATE 文から更新対象のカラムを抽出してみます。

ソースコードこちら

はじめに

更新対象のカラムを抽出するにはパース結果のトークンの中から該当部分を探して値を取得します。

例えば、以下のような UPDATE 文をパースすると

1.sql
UPDATE ITEMS SET price = 1000 WHERE id = '1'

次のようなトークン構成となり、更新するカラムの部分は Comparison となります。

1.sql のパース結果
[<DML 'UPDATE' at 0x200357A3280>, 
 <Whitespace ' ' at 0x200357A32E0>, 
 <Identifier 'ITEMS' at 0x2003579E7A0>, 
 <Whitespace ' ' at 0x200357A33A0>, 
 <Keyword 'SET' at 0x200357A3400>, 
 <Whitespace ' ' at 0x200357A3460>, 
 <Comparison 'price ...' at 0x2003579E8F0>, 
 <Whitespace ' ' at 0x200357A36A0>, 
 <Where 'WHERE ...' at 0x2003579E730>]

また、次のように複数のカラムを更新する場合は IdentifierList (複数の Comparison を持っている)となります。

2.sql
update ITEMS as i set i.price = 200, i.updated_at = NOW(), i.rev = i.rev + 1 where id = '2'
2.sql のパース結果
[<DML 'update' at 0x1EA8A403160>, 
 <Whitespace ' ' at 0x1EA8A403280>, 
 <Identifier 'ITEMS ...' at 0x1EA8A3FEA40>, 
 <Whitespace ' ' at 0x1EA8A4034C0>, 
 <Keyword 'set' at 0x1EA8A403520>, <Whitespace ' ' at 0x1EA8A403580>, 
 <IdentifierList 'i.pric...' at 0x1EA8A3FEB90>, 
 <Whitespace ' ' at 0x1EA8A44C280>, 
 <Where 'where ...' at 0x1EA8A3FE810>]

更新カラムの抽出

それでは、更新対象のカラム(ついでにテーブル名も付与)を抽出してみます。

トークンの型を判定する事になるので、ここでは Python 3.10 でサポートされたパターンマッチを使っています。

テーブル名にエイリアスを使っていると get_name() ではエイリアスが返ってくるため、get_real_name() を使うようにしました。

sample1.py
import sqlparse
from sqlparse.sql import Identifier, IdentifierList, Comparison, Token
from sqlparse.tokens import DML

import sys

sql = sys.stdin.read()

def fields_to_update(st):
    is_update = False
    table = ''

    for t in st.tokens:
        match t:
            # UPDATE 文の場合
            case Token(ttype=tt, value=v) if tt == DML and v.upper() == 'UPDATE':
                is_update = True
            # 複数カラム更新時
            case IdentifierList() if is_update:
                for c in t.tokens:
                    match c:
                        case Comparison(left=Identifier() as l) if is_update:
                            yield f"{table}.{l.get_real_name()}"
            # テーブル名の取得
            case Identifier() if is_update:
                table = t.get_real_name()
            # 単体カラム更新時
            case Comparison(left=Identifier() as l) if is_update:
                yield f"{table}.{l.get_real_name()}"

for s in sqlparse.parse(sql):
    fields = list(fields_to_update(s))
    print(fields)

上記の冗長な部分を再帰処理に変えて改良すると以下のようになりました。

sample2.py
import sqlparse
from sqlparse.sql import Identifier, IdentifierList, Comparison, Token
from sqlparse.tokens import DML

import sys

sql = sys.stdin.read()

def fields_to_update(st):
    def process(ts, table = '', is_update = False):
        for t in ts.tokens:
            match t:
                case Token(ttype=tt, value=v) if tt == DML and v.upper() == 'UPDATE':
                    is_update = True
                case IdentifierList() if is_update:
                    yield from process(t, table, is_update)
                case Identifier() if is_update:
                    table = t.get_real_name()
                case Comparison(left=Identifier() as l) if is_update:
                    yield f"{table}.{l.get_real_name()}"
    
    yield from process(st)

for s in sqlparse.parse(sql):
    fields = list(fields_to_update(s))
    print(fields)

動作確認

以下の SQL を使って動作確認してみます。

3.sql
UPDATE ITEMS SET price = 1000 WHERE id = '1';
SELECT * FROM sample.ITEMS WHERE price > 1000;

update
  sample.ITEMS as i 
set
  i.price = 200, 
  i.updated_at = NOW(), 
  i.rev = i.rev + 1
where
  id = '2';

delete from ITEMS where price <= 0;

実行結果は以下の通りです。

sample1.py 実行結果
$ python sample1.py < 3.sql
['ITEMS.price']
[]
['ITEMS.price', 'ITEMS.updated_at', 'ITEMS.rev']
[]
sample2.py 実行結果
$ python sample2.py < 3.sql
['ITEMS.price']
[]
['ITEMS.price', 'ITEMS.updated_at', 'ITEMS.rev']
[]