Skip to content

Commit

Permalink
Add support of CORS
Browse files Browse the repository at this point in the history
  • Loading branch information
manifest committed Feb 15, 2016
1 parent dbb6360 commit 6868d72
Showing 1 changed file with 176 additions and 0 deletions.
176 changes: 176 additions & 0 deletions src/cowboy_req.erl
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,8 @@
-export([has_resp_header/2]).
-export([has_resp_body/1]).
-export([delete_resp_header/2]).
-export([set_cors_headers/2]).
-export([set_cors_preflight_headers/2]).
-export([reply/2]).
-export([reply/3]).
-export([reply/4]).
Expand All @@ -86,6 +88,30 @@
-export([lock/1]).
-export([to_list/1]).

-type cors_allowed_origins() :: [binary()] | binary().
-type cors_allowed_methods() :: [binary()].
-type cors_allowed_headers() :: [binary()].
-type cors_max_age() :: non_neg_integer() | max.
-type cors_header() :: {origins, cors_allowed_origins()}
| {exposed_headers, cors_allowed_headers()}
| {credentials, boolean()}.
-export_type([cors_header/0]).
-type cors_preflight_header() :: {origins, cors_allowed_origins()}
| {methods, cors_allowed_methods()}
| {headers, cors_allowed_headers()}
| {credentials, boolean()}
| {age, max | non_neg_integer()}.
-export_type([cors_preflight_header/0]).
-record(cors, {
origins = [] :: cors_allowed_origins(),
methods = [] :: cors_allowed_methods(),
headers = [] :: cors_allowed_headers(),
exposed_headers = [] :: cors_allowed_headers(),
credentials = false :: boolean(),
max_age :: cors_max_age()
}).
-type cors_state() :: #cors{}.

-type cookie_opts() :: cow_cookie:cookie_opts().
-export_type([cookie_opts/0]).

Expand Down Expand Up @@ -666,6 +692,137 @@ delete_resp_header(Name, Req=#http_req{resp_headers=RespHeaders}) ->
RespHeaders2 = lists:keydelete(Name, 1, RespHeaders),
Req#http_req{resp_headers=RespHeaders2}.

-spec set_cors_headers([cors_header()], Req) -> Req when Req :: req().
set_cors_headers(Input, Req) ->
try
State = cors_state(Input),
Origin =
match_cors_origin(
header(<<"origin">>, Req),
State#cors.origins),

Req2 = set_cors_allow_credentials(State#cors.credentials, Origin, Req),
set_cors_exposed_headers(State#cors.exposed_headers, Req2)
catch throw:_Reason ->
Req
end.

-spec set_cors_preflight_headers([cors_preflight_header()], Req) -> Req when Req :: req().
set_cors_preflight_headers(Input, Req) ->
try
State = cors_state(Input),
Origin =
match_cors_origin(
header(<<"origin">>, Req),
State#cors.origins),
Method =
match_cors_method(
header(<<"access-control-request-method">>, Req),
State#cors.methods),
Headers =
match_cors_headers(
header(<<"access-control-request-headers">>, Req),
State#cors.headers),

Req2 = set_cors_allow_credentials(State#cors.credentials, Origin, Req),
Req3 = set_cors_max_age(State#cors.max_age, Req2),
Req4 = set_cors_allowed_methods([Method], Req3),
set_cors_allowed_headers(Headers, Req4)
catch throw:_Reason ->
Req
end.

-spec set_cors_allow_credentials(boolean(), binary(), Req) -> Req when Req :: req().
set_cors_allow_credentials(Credentials, Origin, Req) ->
case match_cors_credentials(Credentials, Origin) of
true ->
Req2 = set_resp_header(<<"access-control-allow-origin">>, Origin, Req),
set_resp_header(<<"access-control-allow-credentials">>, <<"true">>, Req2);
_ ->
set_resp_header(<<"access-control-allow-origin">>, Origin, Req)
end.

-spec set_cors_max_age(cors_max_age(), Req) -> Req when Req :: req().
set_cors_max_age(undefined, Req) ->
Req;
set_cors_max_age(max, Req) ->
set_resp_header(<<"access-control-max-age">>, <<"1728000">>, Req);
set_cors_max_age(Val, Req) ->
set_resp_header(<<"access-control-max-age">>, integer_to_binary(Val), Req).

-spec set_cors_allowed_methods(cors_allowed_methods(), Req) -> Req when Req :: req().
%% NOTE: just to make dialyzer happy. We would need this statement
%% if we decided to return an entire list of allowed methods
%% instead of single one passed with the particular request.
%% set_cors_allowed_methods([], Req) ->
%% Req;
set_cors_allowed_methods(Val, Req) ->
set_resp_header(<<"access-control-allow-methods">>, binary_join(Val, <<$,>>), Req).

-spec set_cors_allowed_headers(cors_allowed_headers(), Req) -> Req when Req :: req().
set_cors_allowed_headers([], Req) ->
Req;
set_cors_allowed_headers(Val, Req) ->
set_resp_header(<<"access-control-allow-headers">>, binary_join(Val, <<$,>>), Req).

-spec set_cors_exposed_headers(cors_allowed_headers(), Req) -> Req when Req :: req().
set_cors_exposed_headers([], Req) ->
Req;
set_cors_exposed_headers(L, Req) ->
set_resp_header(<<"access-control-expose-headers">>, binary_join(L, <<$,>>), Req).

-spec match_cors_origin(binary() | undefined, cors_allowed_origins()) -> binary().
match_cors_origin(undefined, Origins) ->
throw({bad_origin, undefined, Origins});
match_cors_origin(Val, Val) ->
Val;
match_cors_origin(Val, <<$*>>) ->
Val;
match_cors_origin(Val, Origins) when is_list(Origins) ->
case lists:member(Val, Origins) of
true -> Val;
_ -> throw({nomatch_origin, Val, Origins})
end;
match_cors_origin(Val, Origins) ->
throw({nomatch_origin, Val, Origins}).

-spec match_cors_method(binary() | undefined, cors_allowed_methods()) -> binary().
match_cors_method(undefined, Methods) ->
throw({bad_method, undefined, Methods});
match_cors_method(Val, Methods) ->
case lists:member(Val, Methods) of
true -> Val;
_ -> throw({nomatch_method, Val, Methods})
end.

-spec match_cors_headers(binary() | undefined, cors_allowed_headers()) -> cors_allowed_headers().
match_cors_headers(undefined, _) ->
[];
match_cors_headers(Val, Headers) ->
lists:filter(
fun(Header) -> lists:member(Header, Headers) end,
binary:split(Val, [<<$,>>, <<", ">>], [global])).

-spec match_cors_credentials(boolean(), binary()) -> boolean().
match_cors_credentials(true, <<$*>>) ->
throw({bad_credentials, true, <<$*>>});
match_cors_credentials(Val, _) ->
Val.

-spec cors_state([cors_header() | cors_preflight_header()]) -> cors_state().
cors_state(Headers) ->
cors_state(Headers, #cors{}).

-spec cors_state([cors_header() | cors_preflight_header()], cors_state()) -> cors_state().
cors_state([{origins, Val}|T], State) -> cors_state(T, State#cors{origins = Val});
cors_state([{methods, Val}|T], State) -> cors_state(T, State#cors{methods = Val});
cors_state([{headers, Val}|T], State) -> cors_state(T, State#cors{headers = Val});
cors_state([{exposed_headers, Val}|T], State) -> cors_state(T, State#cors{exposed_headers = Val});
cors_state([{credentials, Val}|T], State) -> cors_state(T, State#cors{credentials = Val});
cors_state([{max_age, Val}|T], State) -> cors_state(T, State#cors{max_age = Val});
cors_state([_|T], State) -> cors_state(T, State);
cors_state([], State) -> State.

-spec reply(cowboy:http_status(), Req) -> Req when Req::req().
reply(Status, Req=#http_req{resp_body=Body}) ->
reply(Status, [], Body, Req).
Expand Down Expand Up @@ -1244,6 +1401,15 @@ filter_constraints(Tail, Map, Key, Value, Constraints) ->
filter(Tail, Map#{Key => Value2})
end.

-spec binary_join(binary() | [binary()], binary()) -> binary().
binary_join([H|T], Sep) ->
lists:foldl(
fun(Val, Acc) ->
<<Acc/binary, Sep/binary, Val/binary>>
end, H, T).
%%binary_join([], _) -> <<>>;
%%binary_join(L, _) -> L.

%% Tests.

-ifdef(TEST).
Expand Down Expand Up @@ -1298,4 +1464,14 @@ merge_headers_test_() ->
{<<"server">>,<<"Cowboy">>}]}
],
[fun() -> Res = merge_headers(L,R) end || {L, R, Res} <- Tests].

binary_join_test_() ->
Sep = <<$,>>,
Test =
[%%{<<$b>>, <<"b">>},
%%{[], <<>>},
{[<<$a>>], <<$a>>},
{[<<$a>>, <<$b>>], <<"a,b">>}],
[fun() -> Output = binary_join(Input, Sep) end || {Input, Output} <- Test].

-endif.

0 comments on commit 6868d72

Please sign in to comment.