From 6868d729331871b261ee0a1f3fbb9ccd8184abde Mon Sep 17 00:00:00 2001 From: Andrei Nesterov Date: Mon, 15 Feb 2016 23:42:25 +0300 Subject: [PATCH] Add support of CORS --- src/cowboy_req.erl | 176 +++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 176 insertions(+) diff --git a/src/cowboy_req.erl b/src/cowboy_req.erl index 8f0a04b52..b1eecef5c 100644 --- a/src/cowboy_req.erl +++ b/src/cowboy_req.erl @@ -67,6 +67,8 @@ -export([has_resp_header/2]). -export([has_resp_body/1]). -export([delete_resp_header/2]). +-export([set_cors_headers/2]). +-export([set_cors_preflight_headers/2]). -export([reply/2]). -export([reply/3]). -export([reply/4]). @@ -86,6 +88,30 @@ -export([lock/1]). -export([to_list/1]). +-type cors_allowed_origins() :: [binary()] | binary(). +-type cors_allowed_methods() :: [binary()]. +-type cors_allowed_headers() :: [binary()]. +-type cors_max_age() :: non_neg_integer() | max. +-type cors_header() :: {origins, cors_allowed_origins()} + | {exposed_headers, cors_allowed_headers()} + | {credentials, boolean()}. +-export_type([cors_header/0]). +-type cors_preflight_header() :: {origins, cors_allowed_origins()} + | {methods, cors_allowed_methods()} + | {headers, cors_allowed_headers()} + | {credentials, boolean()} + | {age, max | non_neg_integer()}. +-export_type([cors_preflight_header/0]). +-record(cors, { + origins = [] :: cors_allowed_origins(), + methods = [] :: cors_allowed_methods(), + headers = [] :: cors_allowed_headers(), + exposed_headers = [] :: cors_allowed_headers(), + credentials = false :: boolean(), + max_age :: cors_max_age() +}). +-type cors_state() :: #cors{}. + -type cookie_opts() :: cow_cookie:cookie_opts(). -export_type([cookie_opts/0]). @@ -666,6 +692,137 @@ delete_resp_header(Name, Req=#http_req{resp_headers=RespHeaders}) -> RespHeaders2 = lists:keydelete(Name, 1, RespHeaders), Req#http_req{resp_headers=RespHeaders2}. +-spec set_cors_headers([cors_header()], Req) -> Req when Req :: req(). +set_cors_headers(Input, Req) -> + try + State = cors_state(Input), + Origin = + match_cors_origin( + header(<<"origin">>, Req), + State#cors.origins), + + Req2 = set_cors_allow_credentials(State#cors.credentials, Origin, Req), + set_cors_exposed_headers(State#cors.exposed_headers, Req2) + catch throw:_Reason -> + Req + end. + +-spec set_cors_preflight_headers([cors_preflight_header()], Req) -> Req when Req :: req(). +set_cors_preflight_headers(Input, Req) -> + try + State = cors_state(Input), + Origin = + match_cors_origin( + header(<<"origin">>, Req), + State#cors.origins), + Method = + match_cors_method( + header(<<"access-control-request-method">>, Req), + State#cors.methods), + Headers = + match_cors_headers( + header(<<"access-control-request-headers">>, Req), + State#cors.headers), + + Req2 = set_cors_allow_credentials(State#cors.credentials, Origin, Req), + Req3 = set_cors_max_age(State#cors.max_age, Req2), + Req4 = set_cors_allowed_methods([Method], Req3), + set_cors_allowed_headers(Headers, Req4) + catch throw:_Reason -> + Req + end. + +-spec set_cors_allow_credentials(boolean(), binary(), Req) -> Req when Req :: req(). +set_cors_allow_credentials(Credentials, Origin, Req) -> + case match_cors_credentials(Credentials, Origin) of + true -> + Req2 = set_resp_header(<<"access-control-allow-origin">>, Origin, Req), + set_resp_header(<<"access-control-allow-credentials">>, <<"true">>, Req2); + _ -> + set_resp_header(<<"access-control-allow-origin">>, Origin, Req) + end. + +-spec set_cors_max_age(cors_max_age(), Req) -> Req when Req :: req(). +set_cors_max_age(undefined, Req) -> + Req; +set_cors_max_age(max, Req) -> + set_resp_header(<<"access-control-max-age">>, <<"1728000">>, Req); +set_cors_max_age(Val, Req) -> + set_resp_header(<<"access-control-max-age">>, integer_to_binary(Val), Req). + +-spec set_cors_allowed_methods(cors_allowed_methods(), Req) -> Req when Req :: req(). +%% NOTE: just to make dialyzer happy. We would need this statement +%% if we decided to return an entire list of allowed methods +%% instead of single one passed with the particular request. +%% set_cors_allowed_methods([], Req) -> +%% Req; +set_cors_allowed_methods(Val, Req) -> + set_resp_header(<<"access-control-allow-methods">>, binary_join(Val, <<$,>>), Req). + +-spec set_cors_allowed_headers(cors_allowed_headers(), Req) -> Req when Req :: req(). +set_cors_allowed_headers([], Req) -> + Req; +set_cors_allowed_headers(Val, Req) -> + set_resp_header(<<"access-control-allow-headers">>, binary_join(Val, <<$,>>), Req). + +-spec set_cors_exposed_headers(cors_allowed_headers(), Req) -> Req when Req :: req(). +set_cors_exposed_headers([], Req) -> + Req; +set_cors_exposed_headers(L, Req) -> + set_resp_header(<<"access-control-expose-headers">>, binary_join(L, <<$,>>), Req). + +-spec match_cors_origin(binary() | undefined, cors_allowed_origins()) -> binary(). +match_cors_origin(undefined, Origins) -> + throw({bad_origin, undefined, Origins}); +match_cors_origin(Val, Val) -> + Val; +match_cors_origin(Val, <<$*>>) -> + Val; +match_cors_origin(Val, Origins) when is_list(Origins) -> + case lists:member(Val, Origins) of + true -> Val; + _ -> throw({nomatch_origin, Val, Origins}) + end; +match_cors_origin(Val, Origins) -> + throw({nomatch_origin, Val, Origins}). + +-spec match_cors_method(binary() | undefined, cors_allowed_methods()) -> binary(). +match_cors_method(undefined, Methods) -> + throw({bad_method, undefined, Methods}); +match_cors_method(Val, Methods) -> + case lists:member(Val, Methods) of + true -> Val; + _ -> throw({nomatch_method, Val, Methods}) + end. + +-spec match_cors_headers(binary() | undefined, cors_allowed_headers()) -> cors_allowed_headers(). +match_cors_headers(undefined, _) -> + []; +match_cors_headers(Val, Headers) -> + lists:filter( + fun(Header) -> lists:member(Header, Headers) end, + binary:split(Val, [<<$,>>, <<", ">>], [global])). + +-spec match_cors_credentials(boolean(), binary()) -> boolean(). +match_cors_credentials(true, <<$*>>) -> + throw({bad_credentials, true, <<$*>>}); +match_cors_credentials(Val, _) -> + Val. + +-spec cors_state([cors_header() | cors_preflight_header()]) -> cors_state(). +cors_state(Headers) -> + cors_state(Headers, #cors{}). + +-spec cors_state([cors_header() | cors_preflight_header()], cors_state()) -> cors_state(). +cors_state([{origins, Val}|T], State) -> cors_state(T, State#cors{origins = Val}); +cors_state([{methods, Val}|T], State) -> cors_state(T, State#cors{methods = Val}); +cors_state([{headers, Val}|T], State) -> cors_state(T, State#cors{headers = Val}); +cors_state([{exposed_headers, Val}|T], State) -> cors_state(T, State#cors{exposed_headers = Val}); +cors_state([{credentials, Val}|T], State) -> cors_state(T, State#cors{credentials = Val}); +cors_state([{max_age, Val}|T], State) -> cors_state(T, State#cors{max_age = Val}); +cors_state([_|T], State) -> cors_state(T, State); +cors_state([], State) -> State. + -spec reply(cowboy:http_status(), Req) -> Req when Req::req(). reply(Status, Req=#http_req{resp_body=Body}) -> reply(Status, [], Body, Req). @@ -1244,6 +1401,15 @@ filter_constraints(Tail, Map, Key, Value, Constraints) -> filter(Tail, Map#{Key => Value2}) end. +-spec binary_join(binary() | [binary()], binary()) -> binary(). +binary_join([H|T], Sep) -> + lists:foldl( + fun(Val, Acc) -> + <> + end, H, T). +%%binary_join([], _) -> <<>>; +%%binary_join(L, _) -> L. + %% Tests. -ifdef(TEST). @@ -1298,4 +1464,14 @@ merge_headers_test_() -> {<<"server">>,<<"Cowboy">>}]} ], [fun() -> Res = merge_headers(L,R) end || {L, R, Res} <- Tests]. + +binary_join_test_() -> + Sep = <<$,>>, + Test = + [%%{<<$b>>, <<"b">>}, + %%{[], <<>>}, + {[<<$a>>], <<$a>>}, + {[<<$a>>, <<$b>>], <<"a,b">>}], + [fun() -> Output = binary_join(Input, Sep) end || {Input, Output} <- Test]. + -endif.