using DataFrames
using Semirings
using TensorFSTs
using CSV
using BenchmarkTools
using OpenFst
using Glob
using IterTools
using Statistics

include("../../TensorFSTs.jl/lib/OpenFstConvert.jl")
include("tulliocompose.jl")

# Machine-type combinations to benchmark. For each (A, B) entry the main loop
# composes every machine of type A with every machine of type B.
# NOTE(review): this global shadows `Base.pairs` in this module; code evaluated
# here that later calls `pairs(...)` may resolve to this vector instead.
pairs = Tuple{String,String}[
    ("dense", "topology"),
    ("dense", "charlm"),
    ("dense", "random"),
    ("topology", "dense"),
]

"""
    compbench(compfunc, machineA, machineB, seconds; samples=100, evals=1)

Benchmark `compfunc(machineA, machineB)` with BenchmarkTools and return the raw
per-sample timings (`Trial.times`, in nanoseconds).

# Arguments
- `compfunc`: function to time, called as `compfunc(machineA, machineB)`.
- `machineA`, `machineB`: operands. They are interpolated into the benchmark, so
  building them is not part of the measured time.
- `seconds`: time budget for the benchmark run.

# Keywords
- `samples`: maximum number of samples to collect (default `100`).
- `evals`: evaluations per sample (default `1`, so every call is timed on its own).
"""
function compbench(compfunc, machineA, machineB, seconds;
                   samples::Integer=100, evals::Integer=1)
    bench = @benchmarkable $compfunc($machineA, $machineB)
    # `evals` is set explicitly, so `tune!` (which only picks `evals`) is not needed.
    trial = run(bench; samples=samples, seconds=seconds, evals=evals)
    return trial.times
end

"""
    compbench2(compfunc, machineA, machineB, seconds)

Benchmark a sparse-array composition kernel on two machines and return the raw
per-sample timings (`Trial.times`).

Both machines are first converted with `tensorFST2SparseArray`, which is defined
outside this block (presumably in one of the included files — TODO confirm).
The tensors are then trimmed to a common size. Conversion and trimming happen
before the benchmark is built, so only the `compfunc` call is timed.
"""
function compbench2(compfunc, machineA, machineB, seconds)
    tensorA, aA, wA = tensorFST2SparseArray(machineA)
    tensorB, aB, wB = tensorFST2SparseArray(machineB)

    # Trim A along its 4th axis and B along its 3rd axis to their common length.
    # Presumably these are the axes `compfunc` matches up — TODO confirm.
    shared = min(tensorA.dims[4], tensorB.dims[3])
    trimmedA = tensorA[:, :, :, 1:shared]
    trimmedB = tensorB[:, :, 1:shared, :]

    # evals is fixed at 1 (each sample times a single call), so no `tune!` step.
    bench = @benchmarkable $compfunc($trimmedA, $trimmedB, $aA, $wA, $aB, $wB)
    return run(bench; samples=100, seconds=seconds, evals=1).times
end

"""
    renamepath(path)

Build a short identifier `"<type>-<name>"` from a machine path. The second path
component is taken as the machine type and the third as the file name.
Example: `"machines/dense/foo.fst"` gives `"dense-foo"`.

Only a trailing `.fst` extension is removed from the file name. Occurrences of
".fst" elsewhere in the name are kept.

# Throws
- `ArgumentError` if `path` has fewer than three components.
"""
function renamepath(path)
    parts = splitpath(path)
    length(parts) >= 3 ||
        throw(ArgumentError("expected a path like \"<dir>/<type>/<name>.fst\", got \"$path\""))
    machinetype = parts[2]
    filename = parts[3]
    stem, ext = splitext(filename)
    # Strip only the extension; the old `replace(filename, ".fst" => "")` also
    # removed ".fst" from the middle of the name.
    name = ext == ".fst" ? stem : filename
    return machinetype * "-" * name
end

# Root of the MachineZoo.jl checkout, relative to the directory the script runs from.
machinezoo_path = "../../MachineZoo.jl/"

# Benchmark time budgets in seconds: `tseconds` for TensorFSTs (TF.compose),
# `oseconds` for OpenFst (OF.compose).
tseconds = 4
oseconds = 1

# Load every per-type `fstinfo.csv` and stack them into one table. The zoo root
# is passed as Glob's prefix argument rather than being joined into the pattern,
# so an absolute `machinezoo_path` also works.
fstinfo_paths = glob("machines/*/fstinfo.csv", machinezoo_path)
if isempty(fstinfo_paths)
    # Fail early with a clear message; `vcat()` of nothing would yield `Any[]`
    # and fail confusingly further down.
    error("no machines/*/fstinfo.csv files found under $(machinezoo_path)")
end
dfs = [CSV.read(path, DataFrame) for path in fstinfo_paths]
df = reduce(vcat, dfs)
# The machine type is the second component of each `file` entry. Presumably the
# entries look like "machines/<type>/<name>.fst" — TODO confirm against fstinfo.csv.
df[!, "type"] = map(file -> splitpath(file)[2], df[!, "file"])

# Output directory for the composed machines. `mkpath` is idempotent and also
# creates missing parents, so no existence check is needed.
mkpath(joinpath(machinezoo_path, "machines", "composition"))

results = []
for (typeA, typeB) in pairs
    filesA = df[df.type .== typeA, "file"]
    filesB = df[df.type .== typeB, "file"]
    # Base's `Iterators.product` yields every (fileA, fileB) combination.
    for (fileA, fileB) in Iterators.product(filesA, filesB)
        ofstA = OF.read(joinpath(machinezoo_path, fileA))
        ofstB = OF.read(joinpath(machinezoo_path, fileB))
        println((fileA, fileB))

        # Compose once up front. The result is saved to disk, and empty
        # compositions are skipped without benchmarking.
        ofstC = OF.compose(ofstA, ofstB)
        if OF.numstates(ofstC) == 0
            println("skipping")
            continue
        end
        outname = "$(renamepath(fileA))-x-$(renamepath(fileB)).fst"
        OF.write(ofstC, joinpath(machinezoo_path, "machines", "composition", outname))

        otimes = compbench(OF.compose, ofstA, ofstB, oseconds)
        print(" of: ")
        print(mean(otimes))

        # TensorFSTs may fail on some inputs. Record [Inf] so the row is still written.
        ttimes = try
            times = compbench(TF.compose, TF.TensorFST(ofstA), TF.TensorFST(ofstB), tseconds)
            print(" tf: ")
            print(mean(times))
            times
        catch err
            # A bare `catch` would also swallow Ctrl-C and make the run uninterruptible.
            err isa InterruptException && rethrow()
            println("tf failed: ", sprint(showerror, err))
            [Inf]
        end

        # Tullio-based variant (see `compbench2`), disabled:
        # tutimes = compbench2(tulliocompose, TF.TensorFST(ofstA), TF.TensorFST(ofstB), tseconds)
        # print(" tu: ")
        # print(mean(tutimes))
        println()

        push!(results, (fileA=fileA, fileB=fileB,
            tmin=minimum(ttimes), tmax=maximum(ttimes), tmean=mean(ttimes),
            tstd=std(ttimes), tlen=length(ttimes),
            omin=minimum(otimes), omax=maximum(otimes), omean=mean(otimes),
            ostd=std(otimes), olen=length(otimes),
            # tumin=minimum(tutimes), tumax=maximum(tutimes), tumean=mean(tutimes), tustd=std(tutimes), tulen=length(tutimes)
        ))
    end
end

CSV.write("composition_benchmark.csv", DataFrame(results))