Python: パターンマッチしてみる

なんか、趣味では最近はC言語ばっかりだったりするわけですが。

さて、関数型言語系をカジった人なら誰しも取り付かれる、モノ、それがパターンマッチ。パターンマッチが使えると、とにかく直感的にコードをかけますよね。

つーわけで、Pythonでパターンマッチを実装してみました。機能的には

  • リスト,タプルに対するパターンマッチ
  • パターン変数への束縛
  • ガード条件
  • 任意のオブジェクトに対するパターンマッチ
  • 部分パターンの束縛(Ocamlのas)

あたりを実装してみました。これだけあれば、かなり便利にコードをかけます。できるだけ、手軽に書けるように工夫してみました。こんな感じです。

変数束縛とガード。 getattr でごにょごにょしてるので簡単にかけます。

m = Match([1,2,3]) 
if m.when([1,2,m.var]) and m.var > 2:
  print m.var
# >> 3

こう使えば、Pythonに念願のswitchが!

m = Match(10)
if m(9):
  print 1
elif m(10):
  print 2
else:
  False
# >> 2

部分パターンを束縛してみます。 [1,2,m.var] 全体を all というパターン変数に束縛します。

m = Match([1,2,3]) 
if m.when([1,2,m.var]) and m.var > 5:
  False
elif m.when(m._as_all([1,2,m.var])):
  print m.all
  print m.var
else:
  raise StandardError("")
# >> [1, 2, 3]
# >> 3

任意のオブジェクトにも使えます。いわゆるレコードに対するマッチも簡単にできるということです。

class Test(object):
  def __init__(self, v1, v2):
    self.v1 = v1
    self.v2 = v2
  def __repr__(self):
    return "Test(%s, %s)"%(repr(self.v1), repr(self.v2))
m = Match([1, Test(2, 3)])
if m.when([1, m._class(Test, {"v1":2, "v2": m.v2})]):
  print m.v2
else:
  False
# >> 3

オブジェクトに対するパターンマッチは __match__ メソッドを定義するとカスタマイズできます。ここらのアイデアはScalaからいただきました。

class Test2(Test):
  def __match__(self):
    return {"value": self.v1 + self.v2}
m = Match([1, 2, Test2(3,4)])
if m.when([1,2, m._class(Test2, {"value": m.var})]):
  m.var
else:
  False
# >> 7

結構いい感じな気がします。

ダウンロード

patternmatch.py

実装のお話

ソースコードはこんな感じ。

class Match(object):
  class _var(str): pass
  class _class(object):
    def __init__(self, klass, attrs):
      self.klass= klass
      self.attrs= sorted(attrs.iteritems())
    def match(self, m, obj):
      props = getattr(obj, "__match__", lambda: obj.__dict__)()
      return issubclass(obj.__class__, self.klass) and \
            m.when(self.attrs, sorted(props.iteritems()))
  class _as(object):
    def __init__(self, name, pattern = None):
      self.name = name
      self.pattern = pattern
    def __call__(self, pattern):
      self.pattern = pattern
      return self

  def __init__(self, obj):
    self.obj = obj
    self.bind = {}

  def __getitem__(self, key):
    if not self.bind.has_key(key):
      if key.startswith("_as_"):
        return self._as(self._var(key[4:]))
      return self._var(key)
    return self.bind[key]
  __getattr__ = __getitem__
  __call__ = lambda self, *a, **k : self.when(*a, **k)

  def when(self, pattern, obj = None):
    if not obj: obj = self.obj
    if isinstance(pattern, (self._var, self._class, self._as)):
      if isinstance(obj, (list, tuple)):
        pattern = [pattern]
        obj     = [obj]

    if not isinstance(obj, (list, tuple)) and \
      not isinstance(pattern, (list, tuple)) :
      obj = [obj]
      pattern = [pattern]

    if not isinstance(obj, (list, tuple)) or  \
      not isinstance(pattern, (list, tuple)) :
      self.bind = {}
      return False

    if len(obj) != len(pattern):
      if not ((pattern[-1].__class__ == self._var) and pattern[-1].startswith("__")):
        self.bind = {}
        return False

    for i, (value, pat) in enumerate(zip(obj, pattern)):
      if value == pat:
        continue
      elif pat.__class__ == self._var and pat.startswith("__"): 
        self.bind[str(pat)] = obj[i:]
        return True
      elif pat.__class__ == self._var:
        self.bind[str(pat)] = value
      elif pat.__class__ == self._class:
        if not pat.match(self, value):
          self.bind ={}
          return False
      elif pat.__class__ == self._as:
        if not self.when(pat.pattern, value):
          self.bind ={}
          return False
        self.bind[str(pat.name)] = value
      elif isinstance(value, (list, tuple)) and isinstance(pat, (list,tuple)):
        if not self.when(pat, value):
          self.bind = {}
          return False
      else:
        self.bind = {}
        return False

    return True

まぁわりかしシンプルですね。

 

  今年も終わりが近づいてまいりました。年をとると時間がすぎるのが速いナァ・・・と痛感しております。

comments powered by Disqus