diff --git a/vowpalwabbit/gen_cs_example.cc b/vowpalwabbit/gen_cs_example.cc index 8fe85b46696..97eb548269c 100644 --- a/vowpalwabbit/gen_cs_example.cc +++ b/vowpalwabbit/gen_cs_example.cc @@ -115,26 +115,27 @@ void gen_cs_example_ips(cb_to_cs& c, CB::label& ld, COST_SENSITIVE::label& cs_ld //this implements the inverse propensity score method, where cost are importance weighted by the probability of the chosen action //generate cost-sensitive example cs_ld.costs.clear(); - if (ld.costs.size() == 1 || ld.costs.size() == 0) //this is a typical example where we can perform all actions - { - //in this case generate cost-sensitive example with all actions - for (uint32_t i = 1; i <= c.num_actions; i++) + if (ld.costs.size() == 0 || (ld.costs.size() == 1 && ld.costs[0].cost != FLT_MAX)) + //this is a typical example where we can perform all actions { - COST_SENSITIVE::wclass wc = {0.,i,0.,0.}; - if (c.known_cost != nullptr && i == c.known_cost->action) - { - wc.x = c.known_cost->cost / safe_probability(c.known_cost->probability); //use importance weighted cost for observed action, 0 otherwise - //ips can be thought as the doubly robust method with a fixed regressor that predicts 0 costs for everything - //update the loss of this regressor - c.nb_ex_regressors++; - c.avg_loss_regressors += (1.0f / c.nb_ex_regressors)*((c.known_cost->cost)*(c.known_cost->cost) - c.avg_loss_regressors); - c.last_pred_reg = 0; - c.last_correct_cost = c.known_cost->cost; - } - - cs_ld.costs.push_back(wc); + //in this case generate cost-sensitive example with all actions + for (uint32_t i = 1; i <= c.num_actions; i++) + { + COST_SENSITIVE::wclass wc = {0.,i,0.,0.}; + if (c.known_cost != nullptr && i == c.known_cost->action) + { + wc.x = c.known_cost->cost / safe_probability(c.known_cost->probability); //use importance weighted cost for observed action, 0 otherwise + //ips can be thought as the doubly robust method with a fixed regressor that predicts 0 costs for everything + //update the loss of this regressor + c.nb_ex_regressors++; + c.avg_loss_regressors += (1.0f / c.nb_ex_regressors)*((c.known_cost->cost)*(c.known_cost->cost) - c.avg_loss_regressors); + c.last_pred_reg = 0; + c.last_correct_cost = c.known_cost->cost; + } + + cs_ld.costs.push_back(wc); + } } - } else //this is an example where we can only perform a subset of the actions { //in this case generate cost-sensitive example with only allowed actions diff --git a/vowpalwabbit/gen_cs_example.h b/vowpalwabbit/gen_cs_example.h index d9e9865ecdc..f60bf9821cf 100644 --- a/vowpalwabbit/gen_cs_example.h +++ b/vowpalwabbit/gen_cs_example.h @@ -57,7 +57,7 @@ void gen_cs_example_dm(cb_to_cs& c, example& ec, COST_SENSITIVE::label& cs_ld) cs_ld.costs.clear(); c.pred_scores.costs.clear(); - if (ld.costs.size() == 1 || ld.costs.size() == 0) //this is a typical example where we can perform all actions + if (ld.costs.size() == 0 || (ld.costs.size() == 1 && ld.costs[0].cost != FLT_MAX) ) //this is a typical example where we can perform all actions { //in this case generate cost-sensitive example with all actions for (uint32_t i = 1; i <= c.num_actions; i++) { COST_SENSITIVE::wclass wc = {0., i, 0., 0.}; @@ -139,7 +139,8 @@ void gen_cs_example_dr(cb_to_cs& c, example& ec, CB::label& ld, COST_SENSITIVE:: COST_SENSITIVE::wclass temp = { FLT_MAX, i, 0., 0. }; cs_ld.costs.push_back(temp); } - else if (ld.costs.size() == 1 || ld.costs.size() == 0) //this is a typical example where we can perform all actions + else if (ld.costs.size() == 0 || (ld.costs.size() == 1 && ld.costs[0].cost != FLT_MAX) ) + //this is a typical example where we can perform all actions //in this case generate cost-sensitive example with all actions for (uint32_t i = 1; i <= c.num_actions; i++) gen_cs_label(c, ec, cs_ld, i); diff --git a/vowpalwabbit/parser.cc b/vowpalwabbit/parser.cc index 1b2a9b06834..eb01605f920 100644 --- a/vowpalwabbit/parser.cc +++ b/vowpalwabbit/parser.cc @@ -190,7 +190,7 @@ uint32_t cache_numbits(io_buf* buf, int filepointer) version_struct v_tmp(t.begin()); if ( v_tmp != version ) { - cout << "cache has possibly incompatible version, rebuilding" << endl; + // cout << "cache has possibly incompatible version, rebuilding" << endl; t.delete_v(); return 0; }