Skip to content

Commit

Permalink
Merge branch 'cb_dm_semantics' of github.com:JohnLangford/vowpal_wabb…
Browse files Browse the repository at this point in the history
…it into cb_dm_semantics
  • Loading branch information
JohnLangford committed Jul 5, 2018
2 parents a447c2a + 4712533 commit 9d95c81
Show file tree
Hide file tree
Showing 5 changed files with 26 additions and 12 deletions.
5 changes: 4 additions & 1 deletion python/examples/covington.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,10 @@ def _run(self, sentence):
output = [-1 for i in range(N)]
for n in range(N):
# make LDF examples
examples = [ lambda: self.makeExample(sentence,n=n,m=m) for m in range(-1,N) if n != m ]
examples = []
for m in range(-1, N):
if n != m:
examples.append(self.makeExample(sentence=sentence, n=n, m=m))

# truth
parN = sentence[n][1]
Expand Down
7 changes: 2 additions & 5 deletions python/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,11 +103,8 @@ def run(self):

class Sdist(_sdist):
def run(self):
# try to run prep if needed
try:
prep()
except:
pass
# run prep if needed
prep()
_sdist.run(self)


Expand Down
1 change: 1 addition & 0 deletions python/vowpalwabbit/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
# -*- coding: utf-8 -*-
"""Python interfaces for VW"""
18 changes: 16 additions & 2 deletions python/vowpalwabbit/pyvw.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
# -*- coding: utf-8 -*-
"""Python binding for pylibvw class"""

from __future__ import division
import pylibvw


class SearchTask():
"""Search task class"""
def __init__(self, vw, sch, num_actions):
self.vw = vw
self.sch = sch
Expand All @@ -16,7 +20,7 @@ def __del__(self):
def _run(self, your_own_input_example):
pass

def _call_vw(self, my_example, isTest, useOracle=False): # run_fn, setup_fn, takedown_fn, isTest):
def _call_vw(self, my_example, isTest, useOracle=False): # run_fn, setup_fn, takedown_fn, isTest):
self._output = None
self.bogus_example.set_test_only(isTest)
def run(): self._output = self._run(my_example)
Expand All @@ -30,22 +34,27 @@ def run(): self._output = self._run(my_example)
self.vw.learn(self.blank_line) # this will cause our ._run hook to get called

def learn(self, data_iterator):
"""Train search task by providing an iterator of examples"""
for my_example in data_iterator.__iter__():
self._call_vw(my_example, isTest=False);

def example(self, initStringOrDict=None, labelType=pylibvw.vw.lDefault):
"""TODO"""
"""Create an example
initStringOrDict can specify example as VW formatted string, or a dictionary
labelType can specify the desire label type"""
if self.sch.predict_needs_example():
return self.vw.example(initStringOrDict, labelType)
else:
return self.vw.example(None, labelType)

def predict(self, my_example, useOracle=False):
"""Return prediction"""
self._call_vw(my_example, isTest=True, useOracle=useOracle);
return self._output


def get_prediction(ec, prediction_type):
"""Get specified type of prediction from example"""
switch_prediction_type = {
pylibvw.vw.pSCALAR: ec.get_simplelabel_prediction,
pylibvw.vw.pSCALARS: ec.get_scalars,
Expand Down Expand Up @@ -383,6 +392,7 @@ def from_example(self, ex):


class simple_label(abstract_label):
"""Class for simple VW label"""
def __init__(self, label=0., weight=1., initial=0., prediction=0.):
abstract_label.__init__(self)
if isinstance(label, example):
Expand All @@ -407,6 +417,7 @@ def __str__(self):


class multiclass_label(abstract_label):
"""Class for multiclass VW label with prediction"""
def __init__(self, label=1, weight=1., prediction=1):
abstract_label.__init__(self)
if isinstance(label, example):
Expand All @@ -429,6 +440,7 @@ def __str__(self):


class multiclass_probabilities_label(abstract_label):
"""Class for multiclass VW label with probabilities"""
def __init__(self, label, prediction=None):
abstract_label.__init__(self)
if isinstance(label, example):
Expand All @@ -447,6 +459,7 @@ def __str__(self):


class cost_sensitive_label(abstract_label):
"""Class for cost sensative VW label"""
def __init__(self, costs=[], prediction=0):
abstract_label.__init__(self)
if isinstance(costs, example):
Expand Down Expand Up @@ -477,6 +490,7 @@ def __str__(self):


class cbandits_label(abstract_label):
"""Class for contextual bandits VW label"""
def __init__(self, costs=[], prediction=0):
abstract_label.__init__(self)
if isinstance(costs, example):
Expand Down
7 changes: 3 additions & 4 deletions python/vowpalwabbit/sklearn_vw.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,6 @@
# -*- coding: utf-8 -*-
# pylint: unused-argument, invalid-name, too-many-arguments, too-many-locals

"""
Utilities to support integration of Vowpal Wabbit and scikit-learn
"""
"""Utilities to support integration of Vowpal Wabbit and scikit-learn"""

import numpy as np
import re
Expand Down Expand Up @@ -447,9 +444,11 @@ def get_intercept(self):
return self.get_vw().get_weight(CONSTANT_HASH)

def save(self, filename):
"""Save model to file"""
joblib.dump(dict(params=self.get_params(), coefs=self.get_coefs(), fit=self.fit_), filename=filename)

def load(self, filename):
"""Load model from file"""
obj = joblib.load(filename=filename)
self.set_params(**obj['params'])
self.set_coefs(obj['coefs'])
Expand Down

0 comments on commit 9d95c81

Please sign in to comment.