| from dreamcoder.grammar import * |
|
|
# Small additive cost charged for each abstraction/application node when
# costing programs (see VersionTable.minimalInhabitants), so structurally
# larger programs are slightly more expensive than smaller ones.
epsilon = 1e-3
|
|
|
|
def instantiate(context, environment, tp):
    """Instantiate a type together with an environment of types.

    ``tp`` is instantiated first, then every entry of ``environment`` in
    iteration order.  All of them share one ``bindings`` dict and thread the
    evolving ``context``.

    Returns (final context, new environment dict, instantiated tp).
    """
    sharedBindings = {}
    context, instantiated = tp.instantiate(context, sharedBindings)
    freshEnvironment = {}
    for name, entryType in environment.items():
        context, freshEnvironment[name] = entryType.instantiate(context, sharedBindings)
    return context, freshEnvironment, instantiated
|
|
def unify(*environmentsAndTypes):
    """Unify several (environment, type) pairs into a single pair.

    Each pair is freshly instantiated (see ``instantiate``).  Its type is
    unified with one shared result type variable, and any name bound in more
    than one environment has its types unified with each other.

    Returns (merged environment, result type), both with the final
    substitution applied.
    """
    context, result = Context.EMPTY.makeVariable()
    merged = {}
    for environment, tp in environmentsAndTypes:
        context, environment, tp = instantiate(context, environment, tp)
        context = context.unify(result, tp)
        for name, nameType in environment.items():
            if name in merged:
                context = context.unify(merged[name], nameType)
            else:
                merged[name] = nameType
    resolved = {name: nameType.apply(context) for name, nameType in merged.items()}
    return resolved, result.apply(context)
|
|
class Union(Program):
    """Version-space node denoting a set of alternatives.

    ``elements`` is stored as a frozenset.  Unless ``canBeEmpty`` is set, a
    Union must contain at least two distinct elements (checked by assert).
    """
    def __init__(self, elements, canBeEmpty=False):
        self.elements = frozenset(elements)
        if not canBeEmpty:
            assert len(self.elements) > 1

    @property
    def isUnion(self):
        return True

    def __eq__(self, o):
        if not isinstance(o, Union):
            return False
        return self.elements == o.elements

    def __hash__(self):
        return hash(self.elements)

    def __str__(self):
        return "{%s}" % ", ".join(str(element) for element in self.elements)

    def show(self, isFunction):
        # isFunction is accepted for interface compatibility and ignored.
        return str(self)

    def __repr__(self):
        return str(self)

    def __iter__(self):
        return iter(self.elements)
| def __iter__(self): return iter(self.elements) |
|
|
class VersionTable():
    """Hash-consed table of version spaces (sets of programs) addressed by int.

    Each entry of ``self.expressions`` is one node, of one of these kinds:
    - an Index, Primitive or Invented leaf;
    - an Abstraction whose ``body`` is a table index;
    - an Application whose ``f`` and ``x`` are table indices;
    - a Union whose elements are table indices.

    ``expression2index`` maps each node back to its index, so identical
    nodes are stored once.  Two special nodes are created on construction:
    ``self.universe`` (the primitive "U", treated as containing everything)
    and ``self.empty`` (the empty Union).
    """

    def __init__(self, typed=True, identity=True, factored=False):
        """Create a table holding only the universe and empty nodes.

        typed:    if True, substitutions are keyed by (space, type) and
                  candidate rewrites are filtered through ``self.infer``.
        identity: if False, ``recursiveInversion`` discards rewrites whose
                  body is exactly ``$0``.
        factored: if True, ``_substitutions`` encodes applications as a
                  single ``apply`` of two unions per key.
        """
        self.factored = factored
        self.identity = identity
        self.typed = typed
        self.debug = False
        if self.debug:
            print("WARNING: running version spaces in debug mode. Will be substantially slower.")

        self.expressions = []
        self.recursiveTable = []
        self.substitutionTable = {}
        self.expression2index = {}
        self.maximumShift = []
        # Per-node memo of (minimal cost, set of minimal-cost singleton spaces).
        self.inhabitantTable = []
        # Same memo, restricted to programs usable in function position
        # (abstractions get infinite cost there).
        self.functionInhabitantTable = []
        self.superCache = {}

        # Memo for haveOverlap: overlapTable[a][b] -> bool.
        self.overlapTable = {}

        self.universe = self.incorporate(Primitive("U",t0,None))
        self.empty = self.incorporate(Union([], canBeEmpty=True))

    def __len__(self): return len(self.expressions)

    def clearOverlapTable(self):
        """Drop the haveOverlap memo."""
        self.overlapTable = {}

    def visualize(self, j):
        """Render the sub-DAG rooted at ``j`` with graphviz and open a viewer."""
        from graphviz import Digraph
        g = Digraph()

        visited = set()
        def walk(i):
            if i in visited: return

            if i == self.universe:
                g.node(str(i), 'universe')
            elif i == self.empty:
                g.node(str(i), 'nil')
            else:
                l = self.expressions[i]
                if l.isIndex or l.isPrimitive or l.isInvented:
                    g.node(str(i), str(l))
                elif l.isAbstraction:
                    g.node(str(i), "lambda")
                    walk(l.body)
                    g.edge(str(i), str(l.body))
                elif l.isApplication:
                    g.node(str(i), "@")
                    walk(l.f)
                    walk(l.x)
                    g.edge(str(i), str(l.f), label='f')
                    g.edge(str(i), str(l.x), label='x')
                elif l.isUnion:
                    g.node(str(i), "U")
                    for c in l:
                        walk(c)
                        g.edge(str(i), str(c))
                else:
                    assert False
            visited.add(i)
        walk(j)
        g.render(view=True)

    def branchingFactor(self,j):
        """Largest number of alternatives in any Union reachable from ``j`` (0 if none)."""
        l = self.expressions[j]
        if l.isApplication: return max(self.branchingFactor(l.f),
                                       self.branchingFactor(l.x))
        if l.isUnion: return max([len(l.elements)] + [self.branchingFactor(e) for e in l ])
        if l.isAbstraction: return self.branchingFactor(l.body)
        return 0

    def intention(self,j, isFunction=False):
        """Rebuild node ``j`` as a Program tree; Union nodes become Union objects.

        ``isFunction`` is accepted but unused.
        """
        l = self.expressions[j]
        if l.isIndex or l.isPrimitive or l.isInvented: return l
        if l.isAbstraction: return Abstraction(self.intention(l.body))
        if l.isApplication: return Application(self.intention(l.f),
                                               self.intention(l.x))
        if l.isUnion: return Union(self.intention(e)
                                   for e in l )
        assert False

    def walk(self,j):
        """yields every subversion space of j"""
        visited = set()
        def r(n):
            if n in visited: return
            visited.add(n)
            l = self.expressions[n]
            yield l
            if l.isApplication:
                yield from r(l.f)
                yield from r(l.x)
            if l.isAbstraction:
                yield from r(l.body)
            if l.isUnion:
                for e in l:
                    yield from r(e)
        yield from r(j)

    def incorporate(self,p):
        """Add program ``p`` (which may contain Unions) bottom-up; return its index."""
        if p.isIndex or p.isPrimitive or p.isInvented:
            pass
        elif p.isAbstraction:
            p = Abstraction(self.incorporate(p.body))
        elif p.isApplication:
            p = Application(self.incorporate(p.f),
                            self.incorporate(p.x))
        elif p.isUnion:
            # The empty Union is stored as-is; non-empty unions get their
            # elements replaced by table indices.
            if len(p.elements) > 0:
                p = Union([self.incorporate(e) for e in p ])
        else: assert False

        j = self._incorporate(p)
        return j

    def _incorporate(self,p):
        """Hash-cons a node whose children are already indices.

        Returns the existing index if ``p`` is already stored.  Otherwise it
        appends ``p`` and grows every per-node memo table to match.
        """
        if p in self.expression2index: return self.expression2index[p]

        j = len(self.expressions)

        self.expressions.append(p)
        self.expression2index[p] = j
        self.recursiveTable.append(None)
        self.inhabitantTable.append(None)
        self.functionInhabitantTable.append(None)

        return j

    def extract(self,j):
        """Lazily yield every concrete program in the extension of ``j``."""
        l = self.expressions[j]
        if l.isAbstraction:
            for b in self.extract(l.body):
                yield Abstraction(b)
        elif l.isApplication:
            for f in self.extract(l.f):
                for x in self.extract(l.x):
                    yield Application(f,x)
        elif l.isIndex or l.isPrimitive or l.isInvented:
            yield l
        elif l.isUnion:
            for e in l:
                yield from self.extract(e)
        else: assert False

    def reachable(self, heads):
        """Return the set of indices reachable from any index in ``heads``."""
        visited = set()
        def visit(j):
            if j in visited: return
            visited.add(j)

            l = self.expressions[j]
            if l.isUnion:
                for e in l:
                    visit(e)
            elif l.isAbstraction: visit(l.body)
            elif l.isApplication:
                visit(l.f)
                visit(l.x)

        for h in heads:
            visit(h)
        return visited

    def size(self,j):
        """Count leaves under ``j``.

        Abstractions add nothing, and every alternative of a Union is
        counted.
        """
        l = self.expressions[j]
        if l.isApplication:
            return self.size(l.f) + self.size(l.x)
        elif l.isAbstraction:
            return self.size(l.body)
        elif l.isUnion:
            return sum(self.size(e) for e in l )
        else:
            return 1

    def union(self,elements):
        """Union of spaces.

        Nested unions are flattened and ``empty`` members are dropped; the
        universe absorbs everything.  A single survivor is returned directly.
        """
        if self.universe in elements: return self.universe

        _e = []
        for e in elements:
            if self.expressions[e].isUnion:
                for j in self.expressions[e]:
                    _e.append(j)
            elif e != self.empty:
                _e.append(e)

        elements = frozenset(_e)
        if len(elements) == 0: return self.empty
        if len(elements) == 1: return next(iter(elements))
        return self._incorporate(Union(elements))

    def apply(self,f,x):
        """Application node; empty if either side is empty."""
        if f == self.empty: return f
        if x == self.empty: return x
        return self._incorporate(Application(f,x))

    def abstract(self,b):
        """Abstraction node; empty if the body is empty."""
        if b == self.empty: return self.empty
        return self._incorporate(Abstraction(b))

    def index(self,i):
        """Index node for de Bruijn variable ``$i``."""
        return self._incorporate(Index(i))

    def intersection(self,a,b):
        """Structural intersection of spaces ``a`` and ``b``."""
        if a == self.empty or b == self.empty: return self.empty
        if a == self.universe: return b
        if b == self.universe: return a
        if a == b: return a

        x = self.expressions[a]
        y = self.expressions[b]

        if x.isAbstraction and y.isAbstraction:
            return self.abstract(self.intersection(x.body,y.body))
        if x.isApplication and y.isApplication:
            return self.apply(self.intersection(x.f,y.f),
                              self.intersection(x.x,y.x))
        if x.isUnion:
            if y.isUnion:
                return self.union([ self.intersection(x_,y_)
                                    for x_ in x
                                    for y_ in y ])
            return self.union([ self.intersection(x_, b)
                                for x_ in x ])
        if y.isUnion:
            return self.union([ self.intersection(a, y_)
                                for y_ in y ])
        return self.empty

    def haveOverlap(self,a,b):
        """Memoized test of whether spaces ``a`` and ``b`` share a program.

        Follows the same case analysis as ``intersection``, but returns a
        bool instead of building the intersection.  Results are cached in
        ``self.overlapTable[a][b]``.
        """
        if a == self.empty or b == self.empty: return False
        if a == self.universe: return True
        if b == self.universe: return True
        if a == b: return True

        if a in self.overlapTable:
            if b in self.overlapTable[a]:
                return self.overlapTable[a][b]
        else: self.overlapTable[a] = {}

        x = self.expressions[a]
        y = self.expressions[b]

        if x.isAbstraction and y.isAbstraction:
            overlap = self.haveOverlap(x.body,y.body)
        elif x.isApplication and y.isApplication:
            overlap = self.haveOverlap(x.f,y.f) and \
                      self.haveOverlap(x.x,y.x)
        elif x.isUnion:
            if y.isUnion:
                overlap = any( self.haveOverlap(x_,y_)
                               for x_ in x
                               for y_ in y )
            else:
                # Only the non-union case for b; mirrors intersection().
                overlap = any( self.haveOverlap(x_, b)
                               for x_ in x )
        elif y.isUnion:
            overlap = any( self.haveOverlap(a, y_)
                           for y_ in y )
        else:
            overlap = False
        self.overlapTable[a][b] = overlap
        return overlap

    def minimalInhabitants(self,j):
        """Returns (minimal size, set of singleton version spaces)

        Leaves cost 1.  Abstractions and applications add ``epsilon`` to the
        cost of their parts, and an application's function part is costed
        with ``minimalFunctionInhabitants``.  An empty Union costs
        POSITIVEINFINITY and has no members.  Results are memoized in
        ``self.inhabitantTable``.
        """
        assert isinstance(j,int)
        if self.inhabitantTable[j] is not None: return self.inhabitantTable[j]
        e = self.expressions[j]
        if e.isAbstraction:
            cost, members = self.minimalInhabitants(e.body)
            cost = cost + epsilon
            members = {self.abstract(m) for m in members}
        elif e.isApplication:
            fc, fm = self.minimalFunctionInhabitants(e.f)
            xc, xm = self.minimalInhabitants(e.x)
            cost = fc + xc + epsilon
            members = {self.apply(f_,x_)
                       for f_ in fm for x_ in xm }
        elif e.isUnion:
            children = [self.minimalInhabitants(z)
                        for z in e ]
            # default= keeps the empty union from raising ValueError in min().
            cost = min((c for c,_ in children), default=POSITIVEINFINITY)
            members = {zp
                       for c,z in children
                       if c == cost
                       for zp in z }
        else:
            assert e.isIndex or e.isInvented or e.isPrimitive
            cost = 1
            members = {j}

        self.inhabitantTable[j] = (cost, members)

        return cost, members

    def minimalFunctionInhabitants(self,j):
        """Returns (minimal size, set of singleton version spaces)

        Like ``minimalInhabitants``, but for the function position of an
        application.  Abstractions are excluded there (infinite cost, no
        members), and so is an empty Union.  Results are memoized in
        ``self.functionInhabitantTable``.
        """
        assert isinstance(j,int)
        if self.functionInhabitantTable[j] is not None: return self.functionInhabitantTable[j]
        e = self.expressions[j]
        if e.isAbstraction:
            cost = POSITIVEINFINITY
            members = set()
        elif e.isApplication:
            fc, fm = self.minimalFunctionInhabitants(e.f)
            xc, xm = self.minimalInhabitants(e.x)
            cost = fc + xc + epsilon
            members = {self.apply(f_,x_)
                       for f_ in fm for x_ in xm }
        elif e.isUnion:
            children = [self.minimalFunctionInhabitants(z)
                        for z in e ]
            # default= keeps the empty union from raising ValueError in min().
            cost = min((c for c,_ in children), default=POSITIVEINFINITY)
            members = {zp
                       for c,z in children
                       if c == cost
                       for zp in z }
        else:
            assert e.isIndex or e.isInvented or e.isPrimitive
            cost = 1
            members = {j}

        self.functionInhabitantTable[j] = (cost, members)
        return cost, members

    def shiftFree(self,j,n,c=0):
        """Shift free variables (those >= c) of ``j`` down by ``n``.

        Programs that mention a variable in [c, c+n) cannot be shifted and
        become ``empty``.
        """
        if n == 0: return j
        l = self.expressions[j]
        if l.isUnion:
            return self.union([ self.shiftFree(e,n,c)
                                for e in l ])
        if l.isApplication:
            return self.apply(self.shiftFree(l.f,n,c),
                              self.shiftFree(l.x,n,c))
        if l.isAbstraction:
            return self.abstract(self.shiftFree(l.body,n,c+1))
        if l.isIndex:
            if l.i < c: return j
            if l.i >= n + c: return self.index(l.i - n)
            return self.empty
        assert l.isPrimitive or l.isInvented
        return j

    def substitutions(self,j):
        """Yield (v, b) pairs from ``_substitutions(j, 0)``; the type is dropped when typed."""
        if self.typed:
            for (v,_),b in self._substitutions(j,0).items():
                yield v,b
        else:
            yield from self._substitutions(j,0).items()

    def _substitutions(self,j,n):
        """Map each substitutable space v to a body b for variable ``$n`` in ``j``.

        Used by ``inversion``/``recursiveInversion``, which build
        ``apply(abstract(b), v)`` for each entry.  Keys are ``(v, type)``
        when ``self.typed`` and ``v`` otherwise.  Memoized in
        ``self.substitutionTable[(j, n)]``.
        """
        if (j,n) in self.substitutionTable: return self.substitutionTable[(j,n)]

        s = self.shiftFree(j,n)
        if self.debug:
            assert set(self.extract(s)) == set( e.shift(-n)
                                                for e in self.extract(j)
                                                if all( f >= n for f in e.freeVariables() )),\
                                                "shiftFree_%d: %s"%(n,set(self.extract(s)))
        if s == self.empty: m = {}
        else:
            if self.typed:
                # NOTE(review): self.infer and self.bottom are not defined in
                # this class as visible here; the typed path looks like it
                # depends on code that is absent. Confirm before using typed=True.
                principalType = self.infer(s)
                if principalType == self.bottom:
                    print(self.infer(j))
                    print(list(self.extract(j)))
                    print(list(self.extract(s)))
                    assert False
                m = {(s, self.infer(s)[1].canonical()): self.index(n)}
            else:
                m = {s: self.index(n)}

        l = self.expressions[j]
        if l.isPrimitive or l.isInvented:
            m[(self.universe,t0) if self.typed else self.universe] = j
        elif l.isIndex:
            m[(self.universe,t0) if self.typed else self.universe] = \
                j if l.i < n else self.index(l.i + 1)
        elif l.isAbstraction:
            for v,b in self._substitutions(l.body, n + 1).items():
                m[v] = self.abstract(b)
        elif l.isApplication and not self.factored:
            newMapping = {}
            fm = self._substitutions(l.f,n)
            xm = self._substitutions(l.x,n)
            for v1,f in fm.items():
                if self.typed: v1,nType1 = v1
                for v2,x in xm.items():
                    if self.typed: v2,nType2 = v2

                    a = self.apply(f,x)

                    if self.typed:
                        if self.infer(a) == self.bottom: continue
                        try:
                            nType = canonicalUnification(nType1, nType2,
                                                         self.infer(a)[0].get(n,t0))
                        except UnificationFailure:
                            continue

                    # Both sides must agree on what is substituted for $n.
                    v = self.intersection(v1,v2)
                    if v == self.empty: continue
                    if self.typed and self.infer(v) == self.bottom: continue

                    key = (v,nType) if self.typed else v

                    if key in newMapping:
                        newMapping[key].append(a)
                    else:
                        newMapping[key] = [a]
            for v in newMapping:
                newMapping[v] = self.union(newMapping[v])
            # Entries already in m (the whole-term substitution) win on key clashes.
            newMapping.update(m)
            m = newMapping

        elif l.isApplication and self.factored:
            # NOTE(review): when typed, the keys built here are plain spaces,
            # not (space, type) pairs as in the other branches — confirm.
            newMapping = {}
            fm = self._substitutions(l.f,n)
            xm = self._substitutions(l.x,n)
            for v1,f in fm.items():
                if self.typed: v1,nType1 = v1
                for v2,x in xm.items():
                    if self.typed: v2,nType2 = v2
                    v = self.intersection(v1,v2)
                    if v == self.empty: continue
                    if v in newMapping:
                        newMapping[v] = ({f} | newMapping[v][0],
                                         {x} | newMapping[v][1])
                    else:
                        newMapping[v] = ({f},{x})
            for v,(fs,xs) in newMapping.items():
                fs = self.union(list(fs))
                xs = self.union(list(xs))
                m[v] = self.apply(fs,xs)

        elif l.isUnion:
            newMapping = {}
            for e in l:
                for v,b in self._substitutions(e,n).items():
                    if v in newMapping:
                        newMapping[v].append(b)
                    else:
                        newMapping[v] = [b]
            for v in newMapping:
                newMapping[v] = self.union(newMapping[v])
            newMapping.update(m)
            m = newMapping
        else: assert False

        self.substitutionTable[(j,n)] = m

        return m

    def inversion(self,j):
        """One-step inverse beta: union of ``(lambda b) v`` over all non-universe substitutions."""
        i = self.union([self.apply(self.abstract(b),v)
                        for v,b in self.substitutions(j)
                        if v != self.universe])
        if self.debug and self.typed:
            if not (self.infer(i) == self.infer(j)):
                print("inversion produced space with a different type!")
                print("the original type was",self.infer(j))
                print("the type of the rewritten expressions is",self.infer(i))
                print("the original extension was")
                n = None
                for e in self.extract(j):
                    print(e, e.infer())

                    assert n is None or e.betaNormalForm() == n
                    n = e.betaNormalForm()
                print("the rewritten extension is")
                for e in self.extract(i):
                    print(e, e.infer())

                    assert n is None or e.betaNormalForm() == n
                assert self.infer(i) == self.infer(j)
                assert False
        return i

    def recursiveInversion(self,j):
        """Inverse-beta rewrites of ``j`` at the root or inside any one subterm.

        Results, including those for Union nodes, are memoized in
        ``self.recursiveTable``.
        """
        if self.recursiveTable[j] is not None: return self.recursiveTable[j]

        l = self.expressions[j]
        if l.isUnion:
            ru = self.union([self.recursiveInversion(e) for e in l ])
            # Cache unions too so repeated visits do not rebuild them.
            self.recursiveTable[j] = ru
            return ru

        t = [self.apply(self.abstract(b),v)
             for v,b in self.substitutions(j)
             if v != self.universe and (self.identity or b != self.index(0))]
        if self.debug and self.typed:
            ru = self.union(t)
            if not (self.infer(ru) == self.infer(j)):
                print("inversion produced space with a different type!")
                print("the original type was",self.infer(j))
                print("the type of the rewritten expressions is",self.infer(ru))
                print("the original extension was")
                n = None
                for e in self.extract(j):
                    print(e, e.infer())

                    assert n is None or e.betaNormalForm() == n
                    n = e.betaNormalForm()
                print("the rewritten extension is")
                for e in self.extract(ru):
                    print(e, e.infer())

                    assert n is None or e.betaNormalForm() == n
                assert self.infer(ru) == self.infer(j)

        if l.isApplication:
            t.append(self.apply(self.recursiveInversion(l.f),l.x))
            t.append(self.apply(l.f,self.recursiveInversion(l.x)))
        elif l.isAbstraction:
            t.append(self.abstract(self.recursiveInversion(l.body)))

        ru = self.union(t)
        self.recursiveTable[j] = ru
        return ru

    def repeatedExpansion(self,j,n):
        """Return ``[j, R(j), R(R(j)), ...]`` (n+1 spaces), where R is ``recursiveInversion``."""
        spaces = [j]
        for _ in range(n):
            spaces.append(self.recursiveInversion(spaces[-1]))
        return spaces

    def rewriteReachable(self,heads,n):
        """Map every index reachable from ``heads`` to its ``repeatedExpansion(., n)``."""
        vertices = self.reachable(heads)
        spaces = {v: self.repeatedExpansion(v,n)
                  for v in vertices }
        return spaces

    def properVersionSpace(self, j, n):
        """Union of the n-step repeated expansion of ``j``."""
        return self.union(self.repeatedExpansion(j, n))

    def superVersionSpace(self, j, n):
        """Construct decorated tree and then merge version spaces with subtrees via union operator

        For each node i reachable from ``j``, the result is the union of:
        i itself, the spaces in ``repeatedExpansion(i, n)``, and the same
        construction applied to i's children.  The result for ``j`` is stored
        in ``self.superCache[j]``.
        """
        # NOTE(review): the cache key ignores n, so a later call with a
        # different n returns the earlier result. addInventionToGrammar reads
        # superCache[j] directly, so the key shape is kept unchanged.
        if j in self.superCache: return self.superCache[j]
        spaces = self.rewriteReachable({j}, n)
        memo = {}
        def superSpace(i):
            # Subtrees shared by several parents are built only once.
            if i in memo: return memo[i]
            assert i in spaces
            e = self.expressions[i]
            components = [i] + spaces[i]
            if e.isIndex or e.isPrimitive or e.isInvented:
                pass
            elif e.isAbstraction:
                components.append(self.abstract(superSpace(e.body)))
            elif e.isApplication:
                components.append(self.apply(superSpace(e.f), superSpace(e.x)))
            elif e.isUnion: assert False
            else: assert False

            memo[i] = self.union(components)
            return memo[i]
        self.superCache[j] = superSpace(j)
        return self.superCache[j]

    def loadEquivalences(self, g, spaces):
        """Record rewrite equivalences into ``g``.

        ``spaces`` maps a vertex index to a list of rewrite spaces; every list
        is assumed to have the same length N.  For each step n and each vertex
        v (in order of increasing ``size``), every class extracted from
        ``spaces[v][n]`` is either made equivalent (``g.makeEquivalent``) to
        v's representative class of the same type, or becomes that
        representative.

        NOTE(review): ``g`` is assumed to provide incorporate, setOfClasses,
        abstractClass, applyClass, typeOfClass and makeEquivalent — confirm
        against its implementation.
        """
        versionClasses = [None]*len(self.expressions)
        def extract(j):
            if versionClasses[j] is not None:
                return versionClasses[j]

            l = self.expressions[j]
            if l.isAbstraction:
                ks = g.setOfClasses(g.abstractClass(b)
                                    for b in extract(l.body))
            elif l.isApplication:
                fs = extract(l.f)
                xs = extract(l.x)
                ks = g.setOfClasses(g.applyClass(f,x)
                                    for x in xs for f in fs )
            elif l.isUnion:
                ks = g.setOfClasses(e for u in l for e in extract(u))
            else:
                ks = g.setOfClasses({g.incorporate(l)})
            versionClasses[j] = ks
            return ks

        if not spaces: return

        N = len(next(iter(spaces.values())))
        vertices = sorted(spaces.keys(), key=lambda v: self.size(v))

        # vertex -> {type: representative class}
        typedClassesOfVertex = {v: {} for v in vertices }

        for n in range(N):
            for v in vertices:
                expressions = list(self.extract(v))
                assert len(expressions) == 1
                expression = expressions[0]
                k = g.incorporate(expression)
                if k is None: continue
                # Named tk (not t0) so the module-level type t0 is not shadowed.
                tk = g.typeOfClass[k]
                if tk not in typedClassesOfVertex[v]:
                    typedClassesOfVertex[v][tk] = k
                extracted = list(extract(spaces[v][n]))
                for e in extracted:
                    t = g.typeOfClass[e]
                    if t in typedClassesOfVertex[v]:
                        g.makeEquivalent(typedClassesOfVertex[v][t],e)
                    else:
                        # Key by type, like every other entry of this map, so
                        # later same-typed classes get merged with this one.
                        typedClassesOfVertex[v][t] = e

    def bestInventions(self, versions, bs=25):
        """Rank candidate inventions by how well they compress ``versions``.

        versions: [[version index]] -- one list of version-space indices per
            frontier.
        bs: beam size -- each node keeps at most ``bs`` per-candidate costs
            (for both argument and function position).
        returns: list of (indices to) candidates, sorted by increasing score.
            The score sums, over frontiers, the cheapest cost of any program
            in that frontier when the candidate is available.
        """
        import gc

        def nontrivial(proposal):
            # Worth inventing if it has more than one primitive/invented
            # leaf, or exactly one together with a repeated use of the same
            # depth-adjusted variable.
            primitives = 0
            collisions = 0
            indices = set()
            for d, tree in proposal.walk():
                if tree.isPrimitive or tree.isInvented: primitives += 1
                elif tree.isIndex:
                    i = tree.i - d
                    if i in indices: collisions += 1
                    indices.add(i)
            return primitives > 1 or (primitives == 1 and collisions > 0)

        with timing("calculated candidates from version space"):
            # Per frontier: every minimal (function) inhabitant of any node
            # reachable from that frontier's spaces.
            candidates = [{j
                           for k in self.reachable(hs)
                           for _,js in [self.minimalInhabitants(k), self.minimalFunctionInhabitants(k)]
                           for j in js }
                          for hs in versions]
            from collections import Counter
            # Keep candidates that occur in at least two frontiers.
            candidates = Counter(k for ks in candidates for k in ks)
            candidates = {k for k,f in candidates.items() if f >= 2 and nontrivial(next(self.extract(k))) }

            eprint(len(candidates),"candidates from version space")

        # Cost of using a candidate: one for the invention plus one per
        # distinct free variable it must be applied to.
        candidateCost = {k: len(set(next(self.extract(k)).freeVariables())) + 1
                         for k in candidates }

        inhabitTable = self.inhabitantTable
        functionTable = self.functionInhabitantTable

        class B():
            """Beam of per-candidate costs for one node, plus default costs."""
            def __init__(self, j):
                cost, inhabitants = inhabitTable[j]
                # NOTE(review): functionInhabitants is unused; the
                # function-position beam is seeded from `inhabitants`, the
                # same as relativeCost. This may be intentional (an invention
                # applied to its free variables is never an abstraction, so
                # it is valid in function position) — confirm.
                functionCost, functionInhabitants = functionTable[j]
                self.relativeCost = {inhabitant: candidateCost[inhabitant]
                                     for inhabitant in inhabitants
                                     if inhabitant in candidates}
                self.relativeFunctionCost = {inhabitant: candidateCost[inhabitant]
                                             for inhabitant in inhabitants
                                             if inhabitant in candidates}
                self.defaultCost = cost
                self.defaultFunctionCost = functionCost

            @property
            def domain(self):
                return set(self.relativeCost.keys())
            @property
            def functionDomain(self):
                return set(self.relativeFunctionCost.keys())
            def restrict(self):
                # Keep only the bs cheapest candidates in each beam.
                if len(self.relativeCost) > bs:
                    self.relativeCost = dict(sorted(self.relativeCost.items(),
                                                    key=lambda rk: rk[1])[:bs])
                if len(self.relativeFunctionCost) > bs:
                    self.relativeFunctionCost = dict(sorted(self.relativeFunctionCost.items(),
                                                            key=lambda rk: rk[1])[:bs])
            def getCost(self, given):
                return self.relativeCost.get(given, self.defaultCost)
            def getFunctionCost(self, given):
                return self.relativeFunctionCost.get(given, self.defaultFunctionCost)
            def relax(self, given, cost):
                self.relativeCost[given] = min(cost,
                                               self.getCost(given))
            def relaxFunction(self, given, cost):
                self.relativeFunctionCost[given] = min(cost,
                                                       self.getFunctionCost(given))

            def unobject(self):
                # Plain-dict form returned from parallelMap workers.
                return {'relativeCost': self.relativeCost, 'defaultCost': self.defaultCost,
                        'relativeFunctionCost': self.relativeFunctionCost, 'defaultFunctionCost': self.defaultFunctionCost}

        beamTable = [None]*len(self.expressions)

        def costs(j):
            """Memoized bottom-up beam computation for node j."""
            if beamTable[j] is not None:
                return beamTable[j]

            beamTable[j] = B(j)

            e = self.expressions[j]
            if e.isIndex or e.isPrimitive or e.isInvented:
                pass
            elif e.isAbstraction:
                b = costs(e.body)
                for i,c in b.relativeCost.items():
                    beamTable[j].relax(i, c + epsilon)
            elif e.isApplication:
                f = costs(e.f)
                x = costs(e.x)
                for i in f.functionDomain | x.domain:
                    beamTable[j].relax(i, f.getFunctionCost(i) + x.getCost(i) + epsilon)
                    beamTable[j].relaxFunction(i, f.getFunctionCost(i) + x.getCost(i) + epsilon)
            elif e.isUnion:
                for z in e:
                    cz = costs(z)
                    for i,c in cz.relativeCost.items(): beamTable[j].relax(i, c)
                    for i,c in cz.relativeFunctionCost.items(): beamTable[j].relaxFunction(i, c)
            else: assert False

            beamTable[j].restrict()
            return beamTable[j]

        with timing("beamed version spaces"):
            beams = parallelMap(numberOfCPUs(),
                                lambda hs: [ costs(h).unobject() for h in hs ],
                                versions,
                                memorySensitive=True,
                                chunksize=1,
                                maxtasksperchild=1)

        # Release the beam table before the final scoring pass.
        beamTable = None
        gc.collect()

        candidates = {d
                      for _bs in beams
                      for b in _bs
                      for d in b['relativeCost'].keys() }
        def score(candidate):
            return sum(min(min(b['relativeCost'].get(candidate, b['defaultCost']),
                               b['relativeFunctionCost'].get(candidate, b['defaultFunctionCost']))
                           for b in _bs )
                       for _bs in beams )
        candidates = sorted(candidates, key=score)
        return candidates

    def rewriteWithInvention(self, i, js):
        """Rewrites list of indices in beta long form using invention

        ``i`` must contain exactly one program (asserted).  For each space in
        ``js``, the cheapest member is chosen, where any subspace that
        overlaps ``i`` may be replaced by that program at cost 1.  Returns
        the chosen programs, which are ``None`` where no finite-cost choice
        exists.
        """
        self.clearOverlapTable()
        class RW():
            """rewritten cost/expression either as a function or argument"""
            def __init__(self, f,fc,a,ac):
                assert not (fc < ac)
                self.f, self.fc, self.a, self.ac = f,fc,a,ac

        _i = list(self.extract(i))
        assert len(_i) == 1
        _i = _i[0]

        table = {}
        def rewrite(j):
            if j in table: return table[j]
            e = self.expressions[j]
            if self.haveOverlap(i, j): r = RW(fc=1,ac=1,
                                             f=_i,a=_i)
            elif e.isPrimitive or e.isInvented or e.isIndex:
                r = RW(fc=1,ac=1,
                       f=e,a=e)
            elif e.isApplication:
                f = rewrite(e.f)
                x = rewrite(e.x)
                cost = f.fc + x.ac + epsilon
                ep = Application(f.f, x.a) if cost < POSITIVEINFINITY else None
                r = RW(fc=cost, ac=cost,
                       f=ep, a=ep)
            elif e.isAbstraction:
                b = rewrite(e.body)
                cost = b.ac + epsilon
                ep = Abstraction(b.a) if cost < POSITIVEINFINITY else None
                # An abstraction may not appear in function position.
                r = RW(f=None, fc=POSITIVEINFINITY,
                       a=ep, ac=cost)
            elif e.isUnion:
                children = [rewrite(z) for z in e ]
                f,fc = min(( (child.f, child.fc) for child in children ),
                           key=cindex(1))
                a,ac = min(( (child.a, child.ac) for child in children ),
                           key=cindex(1))
                r = RW(f=f,fc=fc,
                       a=a,ac=ac)
            else: assert False
            table[j] = r
            return r
        try:
            js = [ rewrite(j).a for j in js ]
        finally:
            # Release the overlap memo built during this call even if
            # rewriting raises.
            self.clearOverlapTable()
        return js

    def addInventionToGrammar(self, candidate, g0, frontiers,
                              pseudoCounts=1.):
        """Add ``candidate`` as an invention to ``g0`` and rewrite the frontiers with it.

        Every program's ``superCache`` space must already exist, i.e.
        ``superVersionSpace`` has been run on it.  Returns
        (re-estimated grammar, rescored rewritten frontiers).
        """
        candidateSource = next(self.extract(candidate))
        v = RewriteWithInventionVisitor(candidateSource)
        invention = v.invention

        rewriteMapping = list({e.program
                               for f in frontiers
                               for e in f })
        spaces = [self.superCache[self.incorporate(program)]
                  for program in rewriteMapping ]
        rewriteMapping = dict(zip(rewriteMapping,
                                  self.rewriteWithInvention(candidate, spaces)))

        def tryRewrite(program, request=None):
            # Fall back to the original program if the rewrite fails.
            rw = v.execute(rewriteMapping[program], request=request)
            return rw or program

        frontiers = [Frontier([FrontierEntry(program=tryRewrite(e.program, request=f.task.request),
                                             logLikelihood=e.logLikelihood,
                                             logPrior=0.)
                               for e in f ],
                              f.task)
                     for f in frontiers ]

        g = Grammar.uniform([invention] + g0.primitives, continuationType=g0.continuationType).\
            insideOutside(frontiers,
                          pseudoCounts=pseudoCounts)
        frontiers = [g.rescoreFrontier(f) for f in frontiers]
        return g, frontiers
|
|
class CloseInventionVisitor():
    """normalize free variables - e.g., if $1 & $3 occur free then rename them to $0, $1
    then wrap in enough lambdas so that there are no free variables and finally wrap in invention"""
    def __init__(self, p):
        self.p = p
        # original free variable -> dense position, in increasing order
        self.mapping = {free: position
                        for position, free in enumerate(sorted(set(p.freeVariables())))}

    def index(self, e, d):
        # Variables bound inside the fragment (below depth d) are untouched.
        shifted = e.i - d
        if shifted not in self.mapping:
            return e
        return Index(self.mapping[shifted] + d)

    def abstraction(self, e, d):
        return Abstraction(e.body.visit(self, d + 1))

    def application(self, e, d):
        return Application(e.f.visit(self, d), e.x.visit(self, d))

    def primitive(self, e, d):
        return e

    def invented(self, e, d):
        return e

    def execute(self):
        """Return ``Invented`` of the renumbered body wrapped in one lambda per free variable."""
        body = self.p.visit(self, 0)
        for _ in self.mapping:
            body = Abstraction(body)
        return Invented(body)
| |
| |
class RewriteWithInventionVisitor():
    """Replace every occurrence of fragment ``p`` with its closed invention.

    ``self.invention`` is the closed ``Invented`` for ``p``.
    ``self.appliedInvention`` is that invention applied to ``p``'s original
    free variables, and it is what each matching subterm is replaced with.
    """
    def __init__(self, p):
        closer = CloseInventionVisitor(p)
        self.original = p
        # dense position -> original free variable (inverse of closer.mapping)
        self.mapping = {position: free for free, position in closer.mapping.items()}
        self.invention = closer.execute()

        # Apply arguments from the highest position down, so the outermost
        # lambda receives the highest-numbered free variable.
        applied = self.invention
        for position in reversed(range(len(self.mapping))):
            applied = Application(applied, Index(self.mapping[position]))
        self.appliedInvention = applied

    def tryRewrite(self, e):
        """Return the applied invention if ``e`` equals the fragment, else None."""
        return self.appliedInvention if e == self.original else None

    def index(self, e): return e
    def primitive(self, e): return e
    def invented(self, e): return e

    def abstraction(self, e):
        rewritten = self.tryRewrite(e)
        return rewritten if rewritten else Abstraction(e.body.visit(self))

    def application(self, e):
        rewritten = self.tryRewrite(e)
        return rewritten if rewritten else Application(e.f.visit(self),
                                                      e.x.visit(self))

    def execute(self, e, request=None):
        """Rewrite ``e`` and eta-long it for ``request``; return None on UnificationFailure/EtaExpandFailure."""
        try:
            rewritten = e.visit(self)
            return EtaLongVisitor(request=request).execute(rewritten)
        except (UnificationFailure, EtaExpandFailure):
            return None
| |
|
|
|
|
|
|
def induceGrammar_Beta(g0, frontiers, _=None,
                       pseudoCounts=1.,
                       a=3,
                       aic=1.,
                       topK=2,
                       topI=50,
                       structurePenalty=1.,
                       CPUs=1):
    """grammar induction using only version spaces

    Greedy loop.  On each iteration:
    1. build ``a``-step super version spaces for the topK programs of each
       non-empty frontier;
    2. take the top ``topI`` candidates from ``bestInventions``;
    3. score each candidate by adding it to the current grammar (in
       parallel over ``CPUs``);
    4. accept the best one unless its score is below the current score.

    objective = sum of frontierMDL
                - structurePenalty * total primitiveSize of the grammar
                - aic * number of productions.

    Returns (final grammar, frontiers aligned with the original input order;
    empty input frontiers are passed back unchanged).
    """
    from dreamcoder.fragmentUtilities import primitiveSize
    import gc

    originalFrontiers = frontiers
    frontiers = [frontier for frontier in frontiers if not frontier.empty]
    eprint("Inducing a grammar from", len(frontiers), "frontiers")

    arity = a

    # Reads the enclosing g0 and frontiers at call time, so it sees the
    # values rebound at the end of each loop iteration.
    def restrictFrontiers():
        return parallelMap(1,
                           lambda f: g0.rescoreFrontier(f).topK(topK),
                           frontiers,
                           memorySensitive=True,
                           chunksize=1,
                           maxtasksperchild=1)
    restrictedFrontiers = restrictFrontiers()

    def objective(g, fs):
        ll = sum(g.frontierMDL(f) for f in fs )
        sp = structurePenalty * sum(primitiveSize(p) for p in g.primitives)
        return ll - sp - aic*len(g.productions)

    # Rebound to a fresh VersionTable at the top of every iteration;
    # scoreCandidate reads whichever table is current when it runs.
    v = None
    def scoreCandidate(candidate, currentFrontiers, currentGrammar):
        try:
            newGrammar, newFrontiers = v.addInventionToGrammar(candidate, currentGrammar, currentFrontiers,
                                                               pseudoCounts=pseudoCounts)
        except InferenceFailure:
            # A candidate that fails inference is scored as never useful.
            return NEGATIVEINFINITY

        o = objective(newGrammar, newFrontiers)

        eprint(o,'\t',newGrammar.primitives[0],':',newGrammar.primitives[0].tp)

        return o

    with timing("Estimated initial grammar production probabilities"):
        g0 = g0.insideOutside(restrictedFrontiers, pseudoCounts)
    oldScore = objective(g0, restrictedFrontiers)
    eprint("Starting grammar induction score",oldScore)

    while True:
        v = VersionTable(typed=False, identity=False)
        with timing("constructed %d-step version spaces"%arity):
            versions = [[v.superVersionSpace(v.incorporate(e.program), arity) for e in f]
                        for f in restrictedFrontiers ]
            eprint("Enumerated %d distinct version spaces"%len(v.expressions))

        # Beam width for bestInventions is 3*topI; only the top topI are kept.
        candidates = v.bestInventions(versions, bs=3*topI)[:topI]
        eprint("Only considering the top %d candidates"%len(candidates))

        # Reset memo tables no longer needed for scoring, to free memory.
        # superCache and expressions are kept, since addInventionToGrammar
        # reads superCache.
        v.recursiveTable = [None]*len(v)
        v.inhabitantTable = [None]*len(v)
        v.functionInhabitantTable = [None]*len(v)
        v.substitutionTable = {}
        gc.collect()

        with timing("scored the candidate inventions"):
            scoredCandidates = parallelMap(CPUs,
                                           lambda candidate: \
                                           (candidate, scoreCandidate(candidate, restrictedFrontiers, g0)),
                                           candidates,
                                           memorySensitive=True,
                                           chunksize=1,
                                           maxtasksperchild=1)
        if len(scoredCandidates) > 0:
            bestNew, bestScore = max(scoredCandidates, key=lambda sc: sc[1])
        # A tie with oldScore is accepted; only a strictly lower score stops.
        if len(scoredCandidates) == 0 or bestScore < oldScore:
            eprint("No improvement possible.")
            # Put rewritten frontiers back in the caller's order; tasks that
            # were filtered out (empty frontiers) keep their original entry.
            frontiers = {f.task: f for f in frontiers}
            frontiers = [frontiers.get(f.task, f)
                         for f in originalFrontiers]
            return g0, frontiers

        # Spaces so far cover only the restricted (topK) frontiers.
        # addInventionToGrammar reads superCache for every program it
        # rewrites, so build spaces for the full frontiers first.
        with timing("constructed versions bases for entire frontiers"):
            for f in frontiers:
                for e in f:
                    v.superVersionSpace(v.incorporate(e.program), arity)
        newGrammar, newFrontiers = v.addInventionToGrammar(bestNew, g0, frontiers,
                                                           pseudoCounts=pseudoCounts)
        eprint("Improved score to", bestScore, "(dS =", bestScore-oldScore, ") w/ invention",newGrammar.primitives[0],":",newGrammar.primitives[0].infer())
        oldScore = bestScore

        for f in newFrontiers:
            eprint(f.summarizeFull())

        g0, frontiers = newGrammar, newFrontiers
        restrictedFrontiers = restrictFrontiers()
|
|
|
|
| |
| |
| |
| |
| |
| |
def testTyping(p):
    """Debug check of typed versus untyped inversion of program ``p``.

    Builds the typed one-step inversion and a three-fold untyped
    ``recursiveInversion``, then prints both extension sizes.  It asserts
    that the typed extension is exactly the well-typed part of the untyped
    one, and finally exits the interpreter.
    """
    typedTable = VersionTable()
    root = typedTable.incorporate(p)
    wellTyped = set(typedTable.extract(typedTable.inversion(root)))
    print(len(wellTyped))

    untypedTable = VersionTable(typed=False)
    root = untypedTable.incorporate(p)
    for _ in range(3):
        root = untypedTable.recursiveInversion(root)
    arbitrary = set(untypedTable.extract(root))
    print(len(arbitrary))

    assert wellTyped <= arbitrary
    assert wellTyped == {e for e in arbitrary if e.wellTyped()}
    assert all(e.wellTyped() for e in wellTyped)

    import sys
    sys.exit()
| |
def testSharing(projection=2):
    """Plot hard-coded version-space measurements and save them to /tmp/vs.eps.

    Three measurement tables are embedded below. Each has shape ``(L, N)`` and
    is indexed ``[l, n]``. Column ``n = 0`` was never filled and stays zero.
    Only ``version_size`` and ``program_memory`` are plotted;
    ``distinct_programs`` is kept alongside them for reference.

    :param projection: ``3`` draws two translucent 3-D surfaces of ``log10``
        of ``version_size`` (blue) and ``program_memory`` (red) over
        ``n = 1..N-1`` and ``l = 0..L-1``. Any other value draws a 2-D line
        plot of column ``n = 1`` against ``x = 2*l + 3``.

    After saving the figure, the function deliberately fails with
    ``assert False``, so it does not return normally. The exception is when
    Python runs with ``-O``, which strips asserts.
    """
    import numpy as np
    import matplotlib.pyplot as plot
    from matplotlib import rcParams

    N = 4
    L = 6

    # Rows are l = 0..L-1 and columns are n = 0..N-1; column 0 holds no data.
    version_size = np.array([
        [0, 24, 155, 1126],
        [0, 48, 526, 6639],
        [0, 74, 1095, 19633],
        [0, 101, 1751, 38781],
        [0, 129, 2488, 63271],
        [0, 158, 3308, 93400],
    ], dtype=float)
    distinct_programs = np.array([
        [0, 8, 63, 534],
        [0, 24, 457, 8146],
        [0, 57, 2234, 74571],
        [0, 123, 9209, 540315],
        [0, 254, 35011, 3477046],
        [0, 514, 128319, 21042591],
    ], dtype=float)
    program_memory = np.array([
        [0, 28, 201, 1593],
        [0, 78, 1467, 26458],
        [0, 193, 7616, 260865],
        [0, 438, 32931, 1984171],
        [0, 942, 129513, 13179440],
        [0, 1962, 485862, 81433633],
    ], dtype=float)

    rcParams.update({'figure.autolayout': True})

    if projection == 3:
        # Importing Axes3D registers the '3d' projection on matplotlib < 3.2;
        # it is harmless on newer releases.
        from mpl_toolkits.mplot3d import Axes3D  # noqa: F401
        f = plot.figure()
        a = f.add_subplot(111, projection='3d')
        # Start at n = 1: the all-zero n = 0 column would give log10(0) = -inf.
        X, Y = np.meshgrid(np.arange(1, N), np.arange(0, L))
        # NOTE(review): this branch previously read the undefined names
        # `smart` and `dumb`, which raised a NameError. They are assumed to be
        # the "version space" / "no version space" series of the 2-D plot;
        # confirm.
        a.plot_surface(X, Y, np.log10(version_size[:, 1:]),
                       color='blue', alpha=0.3)
        a.plot_surface(X, Y, np.log10(program_memory[:, 1:]),
                       color='red', alpha=0.3)
    else:
        plot.figure(figsize=(3.5, 3))
        plot.tight_layout()
        logarithmic = False
        P = plot.semilogy if logarithmic else plot.plot
        # x position of row l; computed once because the tick labels below
        # also need it.
        xs = np.arange(L) * 2 + 3
        for n in range(1, 2):
            P(xs, version_size[:, n], 'purple',
              label=None if n > 1 else 'version space')
            P(xs, program_memory[:, n], 'green',
              label=None if n > 1 else 'no version space')
        plot.legend()
        plot.xlabel('Size of program being refactored')
        plot.ylabel('Size of VS (purple) or progs (green)')
        # Add one extra tick past the last point, but label only the first
        # tick and the tick at index L - 1.
        ticks = list(xs) + [xs[-1] + 2]
        plot.xticks(ticks,
                    [str(x) if j == 0 or j == L - 1 else ''
                     for j, x in enumerate(ticks)])

    plot.savefig('/tmp/vs.eps')
    # Deliberate hard stop once the figure has been written.
    assert False
|
|
if __name__ == "__main__":
    from dreamcoder.domains.arithmetic.arithmeticPrimitives import *
    from dreamcoder.domains.list.listPrimitives import *
    from dreamcoder.fragmentGrammar import *

    bootstrapTarget_extra()
    # Result discarded here. The call happens before the programs below are
    # parsed, and they use list primitives such as `fix1` and `empty?`.
    McCarthyPrimitives()
    # testSharing ends in `assert False`, so normally nothing below this call
    # runs.
    testSharing()

    programSources = [
        "(lambda (fix1 $0 (lambda (lambda (if (eq? 0 $0) empty (cons (- 0 $0) ($1 (+ 1 $0))))))))",
        "(lambda (fix1 $0 (lambda (lambda (if (empty? $0) 0 (+ (car $0) ($1 (cdr $0))))))))",
        "(lambda (fix1 $0 (lambda (lambda (if (empty? $0) 1 (- (car $0) ($1 (cdr $0))))))))",
        "(lambda (fix1 $0 (lambda (lambda (if (empty? $0) (cons 0 empty) (cons (car $0) ($1 (cdr $0))))))))",
        "(lambda (fix1 $0 (lambda (lambda (if (empty? $0) (empty? empty) (if (car $0) ($1 (cdr $0)) (eq? 1 0)))))))",
    ]
    # Pair each parsed program with its type; None means no type is supplied.
    programs = [(Program.parse(source), None) for source in programSources]
    N = 3

    primitives = McCarthyPrimitives()
    g0 = Grammar.uniform([*primitives])
    print(g0)

    with timing("induced DSL"):
        induceGrammar_Beta(
            g0,
            [Frontier.dummy(p, tp=tp) for p, tp in programs],
            CPUs=1,
            a=N,
            structurePenalty=0.,
        )