Skip to content

Commit

Permalink
refactor interface to mindspore for AD
Browse files Browse the repository at this point in the history
  • Loading branch information
peter0627ustc committed Nov 24, 2023
1 parent b23453c commit 4b3f75b
Showing 1 changed file with 16 additions and 8 deletions.
24 changes: 16 additions & 8 deletions src/backend/toMindspore.jl
Original file line number Diff line number Diff line change
Expand Up @@ -71,12 +71,14 @@ Compile a list of graphs into a string for a python static function and output a

function to_python_str_ms(graphs::AbstractVector{<:AbstractGraph})
head = "import mindspore as ms\n@ms.jit\n"
head *= "def graphfunc(leaf):\n"
body = " graph_list = []\n"
# head *= "def graphfunc(leaf):\n"
# body = " graph_list = []\n"
body = ""
leafidx = 1
root = [id(g) for g in graphs]
inds_visitedleaf = Int[]
inds_visitednode = Int[]
rootidx = 1
for graph in graphs
for g in PostOrderDFS(graph) #leaf first search
g_id = id(g)
Expand All @@ -88,7 +90,7 @@ function to_python_str_ms(graphs::AbstractVector{<:AbstractGraph})
if isempty(subgraphs(g)) #leaf
g_id in inds_visitedleaf && continue
factor_str = factor(g) == 1 ? "" : " * $(factor(g))"
body *= " $target = ms.Tensor(leaf[$(leafidx-1)])$factor_str\n"
body *= " $target = l$(leafidx)$factor_str\n"
leafidx += 1
push!(inds_visitedleaf, g_id)
else
Expand All @@ -98,14 +100,20 @@ function to_python_str_ms(graphs::AbstractVector{<:AbstractGraph})
push!(inds_visitednode, g_id)
end
if isroot
body *= " graph_list.append($target)\n"
body *= " out$(rootidx)=$target\n"
rootidx +=1
end
end
end
tail = " return graph_list\n"
tail*= "def to_StaticGraph(leaf):\n"
tail*= " output = graphfunc(leaf)\n"
tail*= " return output"
input = ["l$(i)" for i in 1:leafidx-1]
input = join(input,",")
output = ["out$(i)" for i in 1:rootidx-1]
output = join(output,",")
head *="def graphfunc($input)\n"
tail = " return $output\n"
# tail*= "def to_StaticGraph(leaf):\n"
# tail*= " output = graphfunc(leaf)\n"
# tail*= " return output"
expr = head * body * tail
println(expr)
# return head * body * tail
Expand Down

0 comments on commit 4b3f75b

Please sign in to comment.