なんか、趣味では最近はC言語ばっかりだったりするわけですが。
さて、関数型言語系をカジった人なら誰しも取り付かれる、モノ、それがパターンマッチ。パターンマッチが使えると、とにかく直感的にコードをかけますよね。
つーわけで、Pythonでパターンマッチを実装してみました。機能的には
- リスト,タプルに対するパターンマッチ
- パターン変数への束縛
- ガード条件
- 任意のオブジェクトに対するパターンマッチ
- 部分パターンの束縛(Ocamlのas)
あたりを実装してみました。これだけあれば、かなり便利にコードをかけます。できるだけ、手軽に書けるように工夫してみました。こんな感じです。
変数束縛とガード。 getattr
でごにょごにょしてるので簡単にかけます。
1m = Match([1,2,3])
2if m.when([1,2,m.var]) and m.var > 2:
3 print m.var
4# >> 3
こう使えば、Pythonに念願のswitchが!
1m = Match(10)
2if m(9):
3 print 1
4elif m(10):
5 print 2
6else:
7 False
8# >> 2
部分パターンを束縛してみます。 [1,2,m.var]
全体を all
というパターン変数に束縛します。
1m = Match([1,2,3])
2if m.when([1,2,m.var]) and m.var > 5:
3 False
4elif m.when(m._as_all([1,2,m.var])):
5 print m.all
6 print m.var
7else:
8 raise StandardError("")
9# >> [1, 2, 3]
10# >> 3
任意のオブジェクトにも使えます。いわゆるレコードに対するマッチも簡単にできるということです。
1class Test(object):
2 def __init__(self, v1, v2):
3 self.v1 = v1
4 self.v2 = v2
5 def __repr__(self):
6 return "Test(%s, %s)"%(repr(self.v1), repr(self.v2))
7m = Match([1, Test(2, 3)])
8if m.when([1, m._class(Test, {"v1":2, "v2": m.v2})]):
9 print m.v2
10else:
11 False
12# >> 3
オブジェクトに対するパターンマッチは __match__
メソッドを定義するとカスタマイズできます。ここらのアイデアはScalaからいただきました。
1class Test2(Test):
2 def __match__(self):
3 return {"value": self.v1 + self.v2}
4m = Match([1, 2, Test2(3,4)])
5if m.when([1,2, m._class(Test2, {"value": m.var})]):
6 m.var
7else:
8 False
9# >> 7
結構いい感じな気がします。
ダウンロード
実装のお話
ソースコードはこんな感じ。
1class Match(object):
2 class _var(str): pass
3 class _class(object):
4 def __init__(self, klass, attrs):
5 self.klass= klass
6 self.attrs= sorted(attrs.iteritems())
7 def match(self, m, obj):
8 props = getattr(obj, "__match__", lambda: obj.__dict__)()
9 return issubclass(obj.__class__, self.klass) and \
10 m.when(self.attrs, sorted(props.iteritems()))
11 class _as(object):
12 def __init__(self, name, pattern = None):
13 self.name = name
14 self.pattern = pattern
15 def __call__(self, pattern):
16 self.pattern = pattern
17 return self
18
19 def __init__(self, obj):
20 self.obj = obj
21 self.bind = {}
22
23 def __getitem__(self, key):
24 if not self.bind.has_key(key):
25 if key.startswith("_as_"):
26 return self._as(self._var(key[4:]))
27 return self._var(key)
28 return self.bind[key]
29 __getattr__ = __getitem__
30 __call__ = lambda self, *a, **k : self.when(*a, **k)
31
32 def when(self, pattern, obj = None):
33 if not obj: obj = self.obj
34 if isinstance(pattern, (self._var, self._class, self._as)):
35 if isinstance(obj, (list, tuple)):
36 pattern = [pattern]
37 obj = [obj]
38
39 if not isinstance(obj, (list, tuple)) and \
40 not isinstance(pattern, (list, tuple)) :
41 obj = [obj]
42 pattern = [pattern]
43
44 if not isinstance(obj, (list, tuple)) or \
45 not isinstance(pattern, (list, tuple)) :
46 self.bind = {}
47 return False
48
49 if len(obj) != len(pattern):
50 if not ((pattern[-1].__class__ == self._var) and pattern[-1].startswith("__")):
51 self.bind = {}
52 return False
53
54 for i, (value, pat) in enumerate(zip(obj, pattern)):
55 if value == pat:
56 continue
57 elif pat.__class__ == self._var and pat.startswith("__"):
58 self.bind[str(pat)] = obj[i:]
59 return True
60 elif pat.__class__ == self._var:
61 self.bind[str(pat)] = value
62 elif pat.__class__ == self._class:
63 if not pat.match(self, value):
64 self.bind ={}
65 return False
66 elif pat.__class__ == self._as:
67 if not self.when(pat.pattern, value):
68 self.bind ={}
69 return False
70 self.bind[str(pat.name)] = value
71 elif isinstance(value, (list, tuple)) and isinstance(pat, (list,tuple)):
72 if not self.when(pat, value):
73 self.bind = {}
74 return False
75 else:
76 self.bind = {}
77 return False
78
79 return True
まぁわりかしシンプルですね。
今年も終わりが近づいてまいりました。年をとると時間がすぎるのが速いナァ・・・と痛感しております。