File size: 3,832 Bytes
094a5f6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
"""
SignalCompiler.jl — compile AI-generated Julia strategy code into callable functions.
Uses no `include`s: the exported Indicators functions are injected explicitly into a
fresh sandbox module created for each strategy.
"""
module SignalCompiler

using Statistics, Random

export compile_strategy, CompiledStrategy

"""
    CompiledStrategy

Outcome of compiling one strategy.

# Fields
- `name`: the strategy name supplied by the caller.
- `generate_fn`: called as `generate_fn(o, h, l, c, v, params)` to produce signals.
- `param_grid_fn`: called with no arguments to obtain the parameter grid.
- `is_valid`: `true` only when compilation succeeded.
- `error`: description of the failure; empty for a valid strategy.
"""
struct CompiledStrategy
    name::String
    generate_fn::Function
    param_grid_fn::Function
    is_valid::Bool
    error::String
end

"""
    CompiledStrategy(name; error="")

Build an invalid placeholder strategy that carries the message `error`.

Its `generate_fn` ignores every input except the close series `c` and returns an
all-zero `Vector{Int}` of the same length; its `param_grid_fn` returns an empty
`Dict{String,Vector{Float64}}`.
"""
function CompiledStrategy(name::String; error::String="")
    flat_signals(o, h, l, c, v, p) = zeros(Int, length(c))
    no_params() = Dict{String,Vector{Float64}}()
    return CompiledStrategy(name, flat_signals, no_params, false, error)
end

"""
    compile_strategy(name, code, ind_mod) -> CompiledStrategy

Evaluate the strategy source `code` inside a fresh anonymous module and wrap its
`generate_signals` and `get_param_grid` functions in a `CompiledStrategy`.

`ind_mod` is the Indicators module, passed from QuantEngine. Every name it exports
(other than the module's own name) is bound as a `const` in the sandbox, followed by
a fixed list of Statistics and Base names.

Problems in the strategy code are reported, not thrown: the result then has
`is_valid == false` and an `error` starting with "Parse: ", "Eval: ", "Missing: ",
or the smoke-test message.

The returned `generate_fn` / `param_grid_fn` call the strategy through
`Base.invokelatest`, so they work even when called from inside the same call chain
that compiled them.

NOTE(review): `Module(name)` imports Base by default, so the sandbox only separates
namespaces; it is not a security boundary. Strategy code can still reach all of Base
and Core, so do not treat it as safe for untrusted input.
"""
function compile_strategy(name::String, code::String, ind_mod::Module)::CompiledStrategy
    safe = replace(replace(name, " " => "_"), r"[^\w]" => "x")
    # The suffix only makes sandbox names distinguishable in stack traces. It is drawn
    # from RandomDevice so that compiling does not advance a caller-seeded global RNG.
    suffix = string(rand(RandomDevice(), UInt16), base=16)
    sandbox = Module(Symbol("S_" * safe * "_" * suffix))

    # Inject all exported Indicators bindings, skipping the module's own name.
    own_name = nameof(ind_mod)
    for sym in names(ind_mod; all=false)
        sym === own_name && continue
        isdefined(ind_mod, sym) && _bind_const!(sandbox, sym, getfield(ind_mod, sym))
    end

    # Inject Statistics.
    for sym in (:mean, :std, :var, :median, :cor, :cov)
        isdefined(Statistics, sym) && _bind_const!(sandbox, sym, getfield(Statistics, sym))
    end

    # Inject a fixed list of Base names. `true`/`false` are literals rather than
    # bindings, so they cannot be (and need not be) injected.
    for sym in (:length, :size, :zeros, :ones, :fill, :similar,
                :sum, :prod, :diff, :cumsum, :cumprod,
                :max, :min, :abs, :sqrt, :log, :exp, :floor, :ceil, :round, :clamp,
                :isnan, :isinf, :isfinite, :sign,
                :sort, :sortperm, :reverse, :unique, :findall, :findfirst,
                :push!, :append!, :pop!, :first, :last, :eachindex,
                :map, :filter, :any, :all, :count,
                :Int, :Int64, :Float64, :Bool,
                :Dict, :Vector, :Tuple, :Set,
                :NaN, :Inf, :pi,
                :println, :string, :get)
        isdefined(Base, sym) && _bind_const!(sandbox, sym, getfield(Base, sym))
    end

    parsed = try
        Meta.parseall(code)
    catch e
        return CompiledStrategy(name; error="Parse: $(sprint(showerror, e))")
    end
    # `Meta.parseall` reports syntax errors by embedding `:error` / `:incomplete` nodes
    # instead of throwing. Detect them before any statement runs, so they are labelled
    # "Parse:" rather than surfacing part-way through evaluation as "Eval:".
    syntax_msg = _syntax_error(parsed)
    syntax_msg === nothing || return CompiledStrategy(name; error="Parse: $syntax_msg")

    try
        Core.eval(sandbox, parsed)
    catch e
        return CompiledStrategy(name; error="Eval: $(sprint(showerror, e))")
    end

    # Newer Julia versions also version global bindings by world age, so look up the
    # names created by the eval above in the latest world.
    Base.invokelatest(isdefined, sandbox, :get_param_grid) ||
        return CompiledStrategy(name; error="Missing: get_param_grid()")
    Base.invokelatest(isdefined, sandbox, :generate_signals) ||
        return CompiledStrategy(name; error="Missing: generate_signals(o,h,l,c,v,params)")

    raw_gen = Base.invokelatest(getfield, sandbox, :generate_signals)
    raw_grid = Base.invokelatest(getfield, sandbox, :get_param_grid)

    # The strategy's methods were added by `Core.eval` after this call began. Under
    # Julia's world-age rules, calling them directly from here (or from any caller still
    # inside this call chain) fails with "MethodError ... method too new".
    # `invokelatest` dispatches in the newest world.
    gen_fn = (args...) -> Base.invokelatest(raw_gen, args...)
    grid_fn = (args...) -> Base.invokelatest(raw_grid, args...)

    err = _smoke(gen_fn, grid_fn)
    isempty(err) || return CompiledStrategy(name; error=err)

    return CompiledStrategy(name, gen_fn, grid_fn, true, "")
end

# Bind `const sym = value` in `mod`. This is best effort: a name that cannot be bound
# (for example one that is already bound there) is skipped and logged at debug level.
function _bind_const!(mod::Module, sym::Symbol, value)
    try
        # QuoteNode makes Symbol/Expr values bind as data instead of being evaluated.
        Core.eval(mod, Expr(:const, Expr(:(=), sym, QuoteNode(value))))
    catch err
        @debug "SignalCompiler: could not bind $sym in sandbox" exception=err
    end
    return nothing
end

# Return the message of the first parser `:error`/`:incomplete` node inside `ex`,
# or `nothing` when the parse tree contains none.
function _syntax_error(ex)
    ex isa Expr || return nothing
    if ex.head === :error || ex.head === :incomplete
        isempty(ex.args) && return "syntax error"
        detail = ex.args[1]
        return detail isa Exception ? sprint(showerror, detail) : string(detail)
    end
    for arg in ex.args
        found = _syntax_error(arg)
        found === nothing || return found
    end
    return nothing
end

# Value used for one grid entry in the smoke run: the first candidate of a (possibly
# empty) vector or range, the value itself for a real scalar, and 0.0 otherwise.
_first_param(vals::AbstractVector) = isempty(vals) ? 0.0 : Float64(first(vals))
_first_param(val::Real) = Float64(val)
_first_param(_) = 0.0

# Return `n` bars of random-walk OHLCV data drawn from `rng`. Open and close both lie
# inside [low, high] (prices stay positive and the 0.3%-scale envelope stays below 1).
function _synthetic_ohlcv(rng::AbstractRNG, n::Integer)
    close_ = 100.0 .* exp.(cumsum(randn(rng, n) .* 0.005))
    open_ = close_ .* (1.0 .+ randn(rng, n) .* 0.001)
    # Build the envelope around both open and close so the bars are internally
    # consistent; strategies that assume low <= open <= high will not fail spuriously.
    high = max.(open_, close_) .* (1.0 .+ abs.(randn(rng, n)) .* 0.003)
    low = min.(open_, close_) .* (1.0 .- abs.(randn(rng, n)) .* 0.003)
    volume = abs.(randn(rng, n)) .* 1000.0 .+ 500.0
    return open_, high, low, close_, volume
end

"""
    _smoke(gen_fn, grid_fn; n=200, rng=MersenneTwister(42)) -> String

Smoke-test a compiled strategy on `n` bars of synthetic OHLCV data.

Calls `grid_fn()`, which must return a `Dict`, and builds one `Dict{String,Float64}`
parameter set from it (see `_first_param`). It then calls
`gen_fn(o, h, l, c, v, params)`, whose result must be a `Vector` of length `n` with
every value in {-1, 0, 1}.

Returns "" on success and a human-readable message otherwise. An exception raised by
the strategy is reported as "Smoke: <error>", except `InterruptException`, which is
rethrown so Ctrl-C is not swallowed. The data come from the private `rng`, so the
check is reproducible and leaves the global RNG untouched.
"""
function _smoke(gen_fn, grid_fn; n::Integer=200,
                rng::AbstractRNG=MersenneTwister(42))::String
    try
        grid = grid_fn()
        grid isa Dict || return "get_param_grid() must return Dict"
        params = Dict{String,Float64}(key => _first_param(vals) for (key, vals) in grid)

        o, h, l, c, v = _synthetic_ohlcv(rng, n)
        sigs = gen_fn(o, h, l, c, v, params)
        sigs isa Vector || return "generate_signals must return Vector, got $(typeof(sigs))"
        length(sigs) != n && return "Signal length $(length(sigs)) ≠ $n"
        any(s -> !(s in (-1, 0, 1)), sigs) && return "Values must be in {-1,0,1}"
    catch e
        e isa InterruptException && rethrow()
        return "Smoke: $(sprint(showerror, e))"
    end
    return ""
end

end # module SignalCompiler