-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathvisualize.ml
138 lines (104 loc) · 4.24 KB
/
visualize.ml
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
(* Copyright (c) 2016 Stefan Krah. BSD 2-Clause License *)
open Printf
(* Raised when the code meets an AST shape it does not expect; e.g.
   AstPrinter.datashape raises it for a Function whose positional part is
   not a Tuple or whose keyword part is not a Record. The string argument
   is a short human-readable description. *)
exception InternalError of string
(*** Convert AST back to re-indented source ***)
(** Renders an [Ast] datashape back to concrete source text on an
    [out_channel]. Every printer takes the current indentation depth [d]
    first and the channel second, so a partially applied printer such as
    [datashape d] can be passed to a ["%a"] conversion. *)
module AstPrinter =
struct
  open Ast

  (** [indent d ch ()] prints [d] spaces; shaped for ["%a"]. *)
  let indent d ch () = fprintf ch "%*s" d ""

  (** Suffix describing a variadic tail: [""] when not variadic, ["..."]
      when the preceding list is empty, and [", ..."] otherwise. *)
  let variadic_flag flag lst =
    match flag, lst with
    | Nonvariadic, _ -> ""
    | Variadic, [] -> "..."
    | Variadic, _ -> ", ..."

  (** [list_single_comma d ch f xs] prints every element with [f d ch],
      separated by [", "] on one line. Prints nothing for [[]]. *)
  let rec list_single_comma (d : int) (ch : out_channel) f = function
    | [a] -> f d ch a
    | a :: r -> f d ch a; fprintf ch ", "; list_single_comma d ch f r
    | [] -> ()

  (** Like [list_single_comma], but emits [", "] before every element
      (including the first), for appending to an already printed prefix. *)
  let rec list_single_push_comma (d : int) (ch : out_channel) f = function
    | a :: r -> fprintf ch ", "; f d ch a; list_single_push_comma d ch f r
    | [] -> ()

  (** Prints the elements separated by [",\n"] followed by [d] spaces, so
      each element after the first starts on a new line at depth [d]. *)
  let rec list_multi_comma d ch f = function
    | [a] -> f d ch a
    | a :: r ->
      f d ch a;
      fprintf ch ",\n%a" (indent d) ();
      list_multi_comma d ch f r
    | [] -> ()

  (** [datashape d ch t] prints [t] to [ch]. [d] is the column at which
      the enclosing construct started; records print their fields on
      separate lines at [d+2] and the closing brace at [d].
      @raise InternalError for a [Function] whose positional part is not a
      [Tuple] or whose keyword part is not a [Record]. *)
  let rec datashape d ch t =
    match t with
    | Void -> fprintf ch "void"
    | Bool -> fprintf ch "bool"
    | Int8 -> fprintf ch "int8"
    | Int16 -> fprintf ch "int16"
    | Int32 -> fprintf ch "int32"
    | Int64 -> fprintf ch "int64"
    | Int128 -> fprintf ch "int128"
    (* Primitive type names are lowercase like their signed counterparts;
       only kinds (Scalar, Fixed, Any, ...) are capitalised. *)
    | Uint8 -> fprintf ch "uint8"
    | Uint16 -> fprintf ch "uint16"
    | Uint32 -> fprintf ch "uint32"
    | Uint64 -> fprintf ch "uint64"
    | Uint128 -> fprintf ch "uint128"
    | Float16 -> fprintf ch "float16"
    | Float32 -> fprintf ch "float32"
    | Float64 -> fprintf ch "float64"
    | Float128 -> fprintf ch "float128"
    | Complex t -> fprintf ch "complex[%a]" (datashape d) t
    | Char enc -> fprintf ch "char[%s]" (string_of_encoding enc)
    | String -> fprintf ch "string"
    | FixedString (size, encoding) ->
      begin match encoding with
        | None -> fprintf ch "fixed_string[%d]" size
        | Some enc ->
          fprintf ch "fixed_string[%d, %s]" size (string_of_encoding enc)
      end
    | Bytes align -> fprintf ch "bytes[align=%d]" align
    | FixedBytes (size, align) ->
      fprintf ch "fixed_bytes[%d, align=%d]" size align
    | Tuple (flag, lst) ->
      let lst' = match flag with
        | Nonvariadic -> lst
        | Variadic -> lst @ [Dtypevar "..."] (* Abuse Dtypevar for printing *)
      in
      fprintf ch "(%a)" (datashape_list d) lst'
    | Record (flag, lst) ->
      let lst' = match flag with
        | Nonvariadic -> lst
        | Variadic ->
          lst @ [("", Dtypevar "...")] (* Abuse Dtypevar for printing *)
      in
      fprintf ch "{\n%a%a\n%a}"
        (indent (d+2)) ()
        (field_declaration_list (d+2)) lst'
        (indent d) ()
    | Function { fun_ret; fun_pos = Tuple (tflag, tlst);
                 fun_kwds = Record (rflag, rlst) } ->
      (* Collect positional args, an optional "...", keyword fields and an
         optional "..." as one list of printers joined by a single ", ".
         This avoids a dangling leading comma when there are no positional
         arguments ("(, a: T)") and a missing comma before a keyword-only
         variadic marker ("(T...)"). *)
      let ellipsis = function
        | Nonvariadic -> []
        | Variadic -> [fun ch -> output_string ch "..."]
      in
      let args =
        List.map (fun t ch -> datashape d ch t) tlst
        @ ellipsis tflag
        @ List.map (fun field ch -> field_declaration d ch field) rlst
        @ ellipsis rflag
      in
      fprintf ch "(%a) -> %a"
        (fun ch ps -> list_single_comma d ch (fun _ ch p -> p ch) ps) args
        (datashape d) fun_ret
    | Function _ -> raise (InternalError "unexpected function")
    | Dtypevar s -> fprintf ch "%s" s
    | Pointer t -> fprintf ch "pointer[%a]" (datashape d) t
    | Option t -> fprintf ch "?%a" (datashape d) t
    | CudaHost t -> fprintf ch "cuda_host[%a]" (datashape d) t
    | CudaDevice t -> fprintf ch "cuda_device[%a]" (datashape d) t
    (* dtype kinds *)
    | ScalarKind -> fprintf ch "Scalar"
    | CategoricalKind -> fprintf ch "Categorical"
    | FixedBytesKind -> fprintf ch "FixedBytes"
    | FixedStringKind -> fprintf ch "FixedString"
    (* general type constructor *)
    | Constr (sym, t) -> fprintf ch "%s[%a]" sym (datashape d) t
    (* arrays *)
    | FixedDim (size, t) -> fprintf ch "%d * %a" size (datashape d) t
    | FixedDimKind t -> fprintf ch "Fixed * %a" (datashape d) t
    | VarDim t -> fprintf ch "var * %a" (datashape d) t
    | SymbolicDim (sym, t) -> fprintf ch "%s * %a" sym (datashape d) t
    | EllipsisDim (sym, t) -> fprintf ch "%s... * %a" sym (datashape d) t
    (* type kinds *)
    | AnyKind -> fprintf ch "Any"

  (** Comma-separated datashapes on one line. *)
  and datashape_list d ch lst = list_single_comma d ch datashape lst

  (** Field declarations, one per line at depth [d]. *)
  and field_declaration_list d ch lst =
    list_multi_comma d ch field_declaration lst

  (** Field declarations on one line, each preceded by [", "]. *)
  and field_declaration_list_tail d ch lst =
    list_single_push_comma d ch field_declaration lst

  (** Prints ["name: type"]; an empty name prints only the type, which is
      how the variadic ["..."] placeholder of a record is rendered. *)
  and field_declaration d ch (s, t) =
    match s with
    | "" -> fprintf ch "%a" (datashape d) t
    | _ -> fprintf ch "%s: %a" s (datashape d) t

  (** Prints [t] at depth 0 to [stdout], followed by a newline. *)
  let print t =
    datashape 0 stdout t;
    fprintf stdout "\n"
end