Skip to content
Snippets Groups Projects
Verified Commit f3175b9c authored by Lucas Ondel Yang's avatar Lucas Ondel Yang
Browse files

merged main

parents 606a99bd 5f35cb2f
No related branches found
No related tags found
1 merge request!64Symbols
......@@ -12,3 +12,5 @@ SparseSemiringAlgebra = "f8e6a300-9d4f-438a-9da3-99be810debe9"
[compat]
Semirings = "2.0"
julia = "1.6"
SparseSemiringAlgebra = "0.6"
......@@ -4,7 +4,13 @@ module NGramFSTs
using ..TensorFSTs
export countngrams, ngramfst
export countngrams,
ngramfst,
read_arpafile,
ngram_get_score,
ngram_score_sentence
include("ngram_arpa.jl")
const BOS = 0
......
# Script for converting between OpenFst and TensorFSTs
# Aliases and lookup tables are `const`: non-const top-level globals are
# `Any`-typed in Julia and defeat specialization in every function using them.
const OF = OpenFst
const TF = TensorFSTs
const SR = Semirings
# Maps a TensorFSTs semiring type to the corresponding OpenFst weight type,
# plus a sign (±1) to be applied to the OpenFst floating point weight.
const _SemiringToWeightType = Dict([
    (SR.LogSemiring{Float32,1}, (OF.LogWeight, -1)),
    (SR.LogSemiring{Float32,-1}, (OF.LogWeight, 1)),
    (SR.LogSemiring{Float64,1}, (OF.Log64Weight, -1)),
    (SR.LogSemiring{Float64,-1}, (OF.Log64Weight, 1)),
    (SR.LogSemiring{Float32,Inf}, (OF.TropicalWeight, -1)),
    (SR.LogSemiring{Float32,-Inf}, (OF.TropicalWeight, 1)),
    (SR.TropicalSemiring{Float32}, (OF.TropicalWeight, 1)),
])
# Converts from OpenFst weight to TensorFSTs semiring.
# NOTE(review): the inverse mapping is not one-to-one (several semirings map
# to the same weight above); round-tripping picks the entries listed here.
const _WeightToSemiringType = Dict([
    (OF.LogWeight, SR.LogSemiring{Float32,-1}),
    (OF.Log64Weight, SR.LogSemiring{Float64,-1}),
    (OF.TropicalWeight, SR.LogSemiring{Float32,-Inf})
])
# Extract the semiring's raw floating point value, applying the sign
# correction `sgn` (±1, taken from `_SemiringToWeightType`).
function _semiring_to_weight(x::S, sgn)::AbstractFloat where S <: SR.Semiring
    return sgn * x.val
end
# If state s is final, return its final weight, else return the semiring zero.
function final(A::TF.TensorFST{S}, s) where S <: SR.Semiring
    # ω stores final weights sparsely; look up the coordinate (s,).
    idx = findfirst(isequal((s,)), A.ω.nzcoo)
    return idx === nothing ? zero(S) : A.ω.nzval[idx]
end
# Return the start state (the first one if there are multiple).
function start(A::TF.TensorFST{S}) where S <: SR.Semiring
    return first(A.α.nzcoo)[1]
end
# Return the sorted list of states of the FST (every arc endpoint,
# deduplicated). Note: states with no arcs at all do not appear.
function states(A::TF.TensorFST{S}) where S <: SR.Semiring
    endpoints = []
    for a in arcs(A)
        append!(endpoints, (a.src, a.dest))
    end
    return sort(unique(endpoints))
end
# Return the arcs of the FST whose source state is q.
function arcs_from_q(A::TF.TensorFST{S}, q::Integer) where S <: SR.Semiring
    selected = []
    for candidate in arcs(A)
        candidate.src[1] == q && push!(selected, candidate)
    end
    return selected
end
# Number of distinct states in the FST (derived from its arc endpoints).
numstates(A::TF.TensorFST{S}) where S <: SR.Semiring = length(states(A))
# Number of arcs leaving state q.
numarcs(A::TF.TensorFST{S},q) where S <: SR.Semiring = length(arcs_from_q(A,q))
# Convert an OpenFst machine to a TensorFST over the matching semiring.
function TF.TensorFST(ofst::OF.Fst{W}) where W <: OF.Weight
    # Semiring type associated with the OpenFst weight type.
    S = _WeightToSemiringType[W]
    arc_list = []
    for state in OF.states(ofst)
        for a in OF.arcs(ofst, state)
            # Labels are shifted by -1 (the reverse +1 shift is applied when
            # converting back in OF.VectorFst).
            arc = (src = convert(Int64, state),
                   isym = convert(Int64, a.ilabel) - 1,
                   osym = convert(Int64, a.olabel) - 1,
                   weight = S(a.weight),
                   dest = convert(Int64, a.nextstate))
            push!(arc_list, arc)
        end
    end
    # Initial weight on the OpenFst start state; final weight of every state.
    initial = [(convert(Int64, OF.start(ofst)),) => S(1)]
    finals = [(convert(Int64, state),) => S(OF.final(ofst, state))
              for state in OF.states(ofst)]
    return TF.TensorFST(arc_list, initial, finals)
end
# Convert a TensorFST to an OpenFst VectorFst, remapping state ids to the
# contiguous 1..n range used by the freshly built OpenFst machine.
function OF.VectorFst(tfst::TF.TensorFST{S}) where S <: SR.Semiring
    W, sgn = _SemiringToWeightType[S]
    ofst = OF.VectorFst{W}()
    # We need expanded for this line only
    OF.reservestates(ofst, numstates(tfst))
    # Map original (possibly non-contiguous) state ids -> enumeration index k,
    # which is the id the k-th addstate! call creates in ofst.
    inv_states_dict = Dict(v[1] => k for (k, v) in enumerate(states(tfst)))
    for (k, s1) in enumerate(states(tfst))
        s = s1[1]
        OF.addstate!(ofst)
        final_w = _semiring_to_weight(final(tfst, s), sgn)
        OF.setfinal!(ofst, k, final_w)
        OF.reservearcs(ofst, k, numarcs(tfst, s))
        for a in arcs_from_q(tfst, s)
            # +1 undoes the -1 label shift applied in TF.TensorFST(::OF.Fst).
            arc = OF.Arc(ilabel = a.isym + 1,
                         olabel = a.osym + 1,
                         weight = _semiring_to_weight(a.weight, sgn),
                         nextstate = inv_states_dict[a.dest[1]])
            OF.addarc!(ofst, k, arc)
        end
    end
    # BUG FIX: every other state reference is remapped through
    # inv_states_dict, but the start state was set to the raw TensorFST id,
    # which is wrong whenever the original ids are not already 1..n.
    # NOTE(review): a start state with no arcs would not appear in states();
    # this matches the pre-existing behavior of states()/numstates.
    OF.setstart!(ofst, inv_states_dict[start(tfst)])
    return ofst
end
\ No newline at end of file
"""
    ngram_read_arpafile(fd)

Read an SRILM ARPA LM from a file descriptor `fd`.

Return a `Vector` indexed by n-gram order; element `n` is a `Dict` mapping a
word tuple to `(logprob, backoff)`:
```
(w1,)         => (logprob, backoff)
(w1, w2)      => (logprob, backoff)
(w1, w2, w3)  => (logprob, backoff)
```
Scores are converted from log base 10 to natural log by `_read_ngram_dict`.
In case a backoff value is not present in the ARPA file, its value is
`nothing`.

NOTE(review): the module exports `read_arpafile` while this function is named
`ngram_read_arpafile` — confirm which name is intended.
"""
function ngram_read_arpafile(fd)
    ngram_count_in_header = Dict()
    # ---- header reading: the "\data\" block of "ngram N=count" lines ----
    while true
        if eof(fd); break; end
        # Stop at the first section keyword (e.g. `\1-grams:`) once at least
        # one count was read; the keyword line is left unconsumed.
        if Char(peek(fd)) == '\\' && length(ngram_count_in_header) > 0
            break
        end
        l = strip(readline(fd))
        splitline = split(l)
        nsplit = length(splitline) # n of columns in line
        # Skip empty lines
        if nsplit == 0; continue; end
        # Consume "\data\" line
        if nsplit == 1 && l == "\\data\\"; continue; end
        # Read an ngram count from the header, e.g. "ngram 1=200003"
        if splitline[1] == "ngram"
            s = match(r"ngram (\d+)=(\d+)", l) # ngram 1=6
            order_ = parse(Int64, s[1]); n = parse(Int64, s[2])
            ngram_count_in_header[order_] = n
            continue
        end
        # Was `@assert false`: asserts can be compiled out and give no
        # context; fail loudly on malformed header input instead.
        error("unexpected line in ARPA header: $(l)")
    end
    # Pre-allocate the output vector: one dictionary per ngram order.
    ngram_dicts = Vector(undef, length(ngram_count_in_header))
    # ---- data reading: one `\N-grams:` section per order ----
    # (was `ngram_order = undef`: `undef` is an array initializer, not a
    # sentinel value; 0 marks "no section seen yet")
    ngram_order = 0
    while true
        if eof(fd); break; end
        l = strip(readline(fd))
        splitline = split(l)
        nsplit = length(splitline) # n of columns in line
        # Skip empty lines
        if nsplit == 0; continue; end
        # -----------------------
        # \{N}-grams section
        # -----------------------
        if nsplit == 1 && occursin(r"\\.*-grams:$", l)
            # Section header line: \{N}-grams. The capture must be (\d+),
            # not (\d)+ which only keeps the LAST digit and mis-parses
            # orders >= 10.
            s = match(r"\\(\d+)-grams.*", l)
            ngram_order = parse(Int64, s[1])
            # Read the ngrams of this order
            ngram_dict = _read_ngram_dict(fd, ngram_order)
            if length(ngram_dict) != ngram_count_in_header[ngram_order]
                println("""WARNING: number of ngrams does not match !
                      $(ngram_order)-gram,
                      count in header $(ngram_count_in_header[ngram_order]),
                      loaded ngrams $(length(ngram_dict))
                      """
                )
            end
            # Store the ngrams of this order in the output vector
            ngram_dicts[ngram_order] = ngram_dict
        end
    end
    ngram_dicts
end
"""
    _read_ngram_dict(fd, ngram_order)::Dict

Read one table of n-grams with the same order from an ARPA file.

Each accepted line has the form `logprob w1 ... wN [backoff]`. Return a
`Dict` mapping the word tuple `(w1, ..., wN)` to `(logprob, backoff)`, with
both scores converted from log base 10 to natural log; `backoff` is
`nothing` when the column is absent. Reading stops (without consuming the
line) at the next `\\`-keyword line or at end of file.
"""
function _read_ngram_dict(fd::T, ngram_order::Int64) where {T <: IO}
    ngram_dict = Dict()
    # ln(x) = log10(x) * ln(10), so the base-10 -> base-e factor is log(10).
    # (Was `1. / log10()`: `log10` requires an argument, so that line threw
    # a MethodError; 1/log10(e) == log(10).)
    log_base_10_to_e = log(10.0)
    while true
        # file ended ? Must be checked BEFORE peek: peeking an exhausted
        # stream throws EOFError.
        if eof(fd); break; end
        # keyword follows ?
        if Char(peek(fd)) == '\\'; break; end # `\2-grams:` line not consumed
        # readline
        line = strip(readline(fd))
        arpa_columns = split(line)
        nsplit = length(arpa_columns)
        # skip empty lines
        if (nsplit == 0); continue; end
        # parse the line: logprob, `ngram_order` words, optional backoff
        prob = parse(Float64, arpa_columns[1])
        ngram = arpa_columns[2:(1 + ngram_order)]
        backoff = nothing
        if (nsplit == ngram_order + 2)
            backoff = parse(Float64, arpa_columns[ngram_order + 2])
        end
        # convert log-base: 10 -> e
        prob *= log_base_10_to_e
        if backoff !== nothing
            backoff *= log_base_10_to_e
        end
        ngram_dict[(ngram...,)] = (prob, backoff)
    end
    return ngram_dict
end
"""
    ngram_get_score(ngram_dicts, ngram; unk="<UNK>")

Get the n-gram log-score for a single n-gram represented as a tuple
`(w1, w2, w3)`.

Words missing from the unigram table are replaced by `unk`. If the full
n-gram is absent, recursively back off: score of `(w2, ..., wN)` plus the
backoff weight of the history `(w1, ..., wN-1)` (0.0 when missing).

Throws `ArgumentError` if `ngram` is longer than the model order.
"""
function ngram_get_score(ngram_dicts::Vector, ngram::Tuple; unk::String="<UNK>")
    # Input validation must not be an @assert (asserts may be disabled).
    length(ngram) <= length(ngram_dicts) ||
        throw(ArgumentError("ngram order $(length(ngram)) exceeds model order $(length(ngram_dicts))"))
    # Replace OOVs (words absent from the unigram table) with `unk`.
    ngram = map(w -> haskey(ngram_dicts[1], (w,)) ? w : unk, ngram)
    n = length(ngram)
    if haskey(ngram_dicts[n], ngram)
        # Exact hit: return the stored log-probability.
        return ngram_dicts[n][ngram][1]
    end
    if n == 1
        # After OOV replacement the unigram must exist (unless `unk` itself
        # is missing from the model).
        error("Key error for unigram `$(ngram)`. This should not happen...")
    end
    # Backoff recursion on the shortened n-gram.
    prob = ngram_get_score(ngram_dicts, ngram[2:end]; unk=unk)
    history = ngram[begin:(end - 1)]
    backoff = haskey(ngram_dicts[n - 1], history) ? ngram_dicts[n - 1][history][2] : nothing
    # 0.0 if the history is missing or the ARPA entry had no backoff value.
    backoff = something(backoff, 0.0)
    return backoff + prob # log scores add
end
"""
    ngram_score_sentence(ngram_dicts, sentence; unk="<UNK>")

Score a sentence using the n-gram model.

The `sentence` is split into words and wrapped in `<s>`/`</s>` markers.
Per-n-gram scores are printed and the summed log-score is returned.
"""
function ngram_score_sentence(ngram_dicts::Vector, sentence::String; unk::String="<UNK>")
    ngram_order = length(ngram_dicts)
    words = split(sentence)
    pushfirst!(words, "<s>")
    push!(words, "</s>")
    scores = []
    ngrams = []
    # Score every position j >= 2 with an n-gram clipped both to the model
    # order and to the beginning of the sentence.
    # (Was `range(2, length(words))`: the two-positional-argument `range`
    # requires Julia >= 1.7, but [compat] declares julia = "1.6".)
    for w_j in 2:length(words)
        w_i = max(w_j - ngram_order + 1, 1)
        ngram = (words[w_i:w_j]...,) # tuple of words
        push!(ngrams, ngram)
        push!(scores, ngram_get_score(ngram_dicts, ngram; unk=unk))
    end
    println("LM scores $(collect(zip(ngrams, scores)))")
    return sum(scores)
end
# Example:
#
# ngram_dicts = ngram_read_arpafile("/mnt/matylda5/iveselyk/FAST-ASR_julia/lm_arpa_librispeech/3-gram.pruned.3e-7.arpa.gz")
#
# ngram_get_score(ngram_dicts, ("A", "GENTLE", "KICK"))
#
# ngram_score_sentence(ngram_dicts, "A GENTLE KICK FROM THE TALL BOY IN THE BENCH BEHIND URGED STEPHEN TO ASK A DIFFICULT QUESTION")
......@@ -34,7 +34,6 @@ export closure,
include("ops.jl")
## standalone libraries
#include("../lib/LinearFSTs.jl")
#include("../lib/NGramFSTs.jl")
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment