Project import generated by Copybara.
GitOrigin-RevId: d9e9e3fb4e31372ec1fb43b178994ca78fa8fe70
diff --git a/tools/datalab/simple_end_to_end.ipynb b/tools/datalab/simple_end_to_end.ipynb
new file mode 100644
index 0000000..326c49c
--- /dev/null
+++ b/tools/datalab/simple_end_to_end.ipynb
@@ -0,0 +1,365 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "collapsed": false
+ },
+ "outputs": [],
+ "source": [
+ "%pylab inline"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "collapsed": true
+ },
+ "outputs": [],
+ "source": [
+ "from __future__ import print_function\n",
+ "from __future__ import division"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "collapsed": false
+ },
+ "outputs": [],
+ "source": [
+ "import pickle\n",
+ "import sys\n",
+ "import time\n",
+ "import unicodedata\n",
+ "\n",
+ "import pandas as pd\n",
+ "import seaborn as sns\n",
+ "import sklearn\n",
+ "from sklearn.cross_validation import train_test_split\n",
+ "from sklearn.feature_extraction.text import CountVectorizer, TfidfTransformer\n",
+ "from sklearn.multiclass import OneVsRestClassifier\n",
+ "from sklearn.preprocessing import MultiLabelBinarizer\n",
+ "from sklearn.svm import LinearSVC"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "collapsed": false
+ },
+ "outputs": [],
+ "source": [
+ "# NOTE: unpickling can execute arbitrary code; only load files from a trusted source.\n",
+ "# Pickles are opened in binary mode, and `with` closes each handle after loading.\n",
+ "with open(\"subset_issue.pkl\", \"rb\") as issues_file:\n",
+ "    issues = pickle.load(issues_file)\n",
+ "with open(\"comment_text.pkl\", \"rb\") as comments_file:\n",
+ "    comment_text = pickle.load(comments_file)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Build a translation table for removing punctuation from text."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "collapsed": true
+ },
+ "outputs": [],
+ "source": [
+ "# Map every code point in a Unicode punctuation category (P*) to None so that\n",
+ "# translating a unicode string with this table deletes those characters.\n",
+ "# The range is inclusive of sys.maxunicode, the highest code point of this build.\n",
+ "table = {\n",
+ "    code_point: None\n",
+ "    for code_point in xrange(sys.maxunicode + 1)\n",
+ "    if unicodedata.category(unichr(code_point)).startswith('P')\n",
+ "}"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Clean the text"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "collapsed": false
+ },
+ "outputs": [],
+ "source": [
+ "def get_text_components_per_issue(issues):\n",
+ "    \"\"\"Build the cleaned text and the component set for every issue.\n",
+ "\n",
+ "    Relies on two notebook-level names: `comment_text`, indexed by comment id\n",
+ "    to obtain the comment's text, and `table`, the translation table used to\n",
+ "    delete punctuation.\n",
+ "\n",
+ "    Args:\n",
+ "        issues: object whose iterrows() yields (index, row) pairs; each row is\n",
+ "            indexed by \"comments\" (iterable of comment ids) and \"components\"\n",
+ "            (iterable of component ids).\n",
+ "\n",
+ "    Returns:\n",
+ "        A (text_per_issue, components_per_issue) tuple of parallel lists: one\n",
+ "        string of space-joined cleaned comments and one set of components per\n",
+ "        issue.\n",
+ "    \"\"\"\n",
+ "    text_per_issue = []\n",
+ "    components_per_issue = []\n",
+ "\n",
+ "    for _, row in issues.iterrows():\n",
+ "        # Strip each comment and delete its punctuation, then join once rather\n",
+ "        # than growing a string with += inside the loop.\n",
+ "        cleaned_comments = [comment_text[comment_id].strip().translate(table)\n",
+ "                            for comment_id in row[\"comments\"]]\n",
+ "        text_per_issue.append(\" \".join(cleaned_comments).strip())\n",
+ "        components_per_issue.append(set(row[\"components\"]))\n",
+ "\n",
+ "    return text_per_issue, components_per_issue"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "collapsed": false
+ },
+ "outputs": [],
+ "source": [
+ "text_per_issue, components_per_issue = get_text_components_per_issue(issues)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Filter out components that are used infrequently (not enough signal) or too frequently (signal not meaningful)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "collapsed": false
+ },
+ "outputs": [],
+ "source": [
+ "def prune_and_bin_components(components_per_issue, prune_low=0.005, prune_high=0.25):\n",
+ "    \"\"\"Drop too-rare and too-common components, then binarize what remains.\n",
+ "\n",
+ "    A component's frequency is its share of all component assignments: its\n",
+ "    column sum divided by the sum of the whole indicator matrix. Components\n",
+ "    whose frequency is not strictly between prune_low and prune_high are\n",
+ "    removed from every issue's component set.\n",
+ "\n",
+ "    Args:\n",
+ "        components_per_issue: list of sets of component ids, one per issue.\n",
+ "        prune_low: exclusive lower bound on the frequency of a kept component.\n",
+ "        prune_high: exclusive upper bound on the frequency of a kept component.\n",
+ "\n",
+ "    Returns:\n",
+ "        (bins, mlb): the indicator matrix of the pruned component sets and the\n",
+ "        MultiLabelBinarizer fitted on them.\n",
+ "    \"\"\"\n",
+ "    full_mlb = MultiLabelBinarizer()\n",
+ "    full_bins = full_mlb.fit_transform(components_per_issue)\n",
+ "\n",
+ "    # Compute the frequencies once, explicitly in floating point, so the result\n",
+ "    # does not depend on `from __future__ import division` being active.\n",
+ "    # max(..., 1) avoids dividing by zero when no component is assigned at all.\n",
+ "    component_counts = full_bins.sum(axis=0)\n",
+ "    frequency = component_counts.astype(float) / max(full_bins.sum(), 1)\n",
+ "    keep = (frequency > prune_low) & (frequency < prune_high)\n",
+ "    exclude_comp_ids = set(full_mlb.classes_[~keep])\n",
+ "\n",
+ "    comps_per_issue_exclude = [comp_set - exclude_comp_ids\n",
+ "                               for comp_set in components_per_issue]\n",
+ "\n",
+ "    # Refit on the pruned sets so the returned classes only list kept components.\n",
+ "    mlb = MultiLabelBinarizer()\n",
+ "    bins = mlb.fit_transform(comps_per_issue_exclude)\n",
+ "    return bins, mlb"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "collapsed": false
+ },
+ "outputs": [],
+ "source": [
+ "bins, mlb = prune_and_bin_components(components_per_issue)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Tokenize the text and perform tfidf transformations"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "collapsed": false
+ },
+ "outputs": [],
+ "source": [
+ "bigram_vectorizer = CountVectorizer(ngram_range=(1, 2),\n",
+ " token_pattern=r'\\b\\w+\\b',\n",
+ " min_df=5,\n",
+ " max_df=0.5,\n",
+ " stop_words='english')\n",
+ "\n",
+ "tfidf_transformer = TfidfTransformer()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "collapsed": false
+ },
+ "outputs": [],
+ "source": [
+ "counts = bigram_vectorizer.fit_transform(text_per_issue)\n",
+ "tfidf = tfidf_transformer.fit_transform(counts)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "collapsed": true
+ },
+ "outputs": [],
+ "source": [
+ "X_train, X_test, y_train, y_test = train_test_split(tfidf, bins, train_size=0.8, random_state=42)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Train a very simple linear model"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "collapsed": false
+ },
+ "outputs": [],
+ "source": [
+ "clf = OneVsRestClassifier(LinearSVC(C=1.0))\n",
+ "clf.fit(X_train, y_train)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Predict and analyze the results"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "collapsed": false
+ },
+ "outputs": [],
+ "source": [
+ "predictions = clf.predict(X_test)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "collapsed": false
+ },
+ "outputs": [],
+ "source": [
+ "# Share of individual (issue, component) label cells that were predicted correctly.\n",
+ "label_matches = (y_test == predictions)\n",
+ "label_matches.sum() / label_matches.size"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "collapsed": false
+ },
+ "outputs": [],
+ "source": [
+ "# Share of issues whose whole label row is predicted exactly. The row-wise all()\n",
+ "# works for any number of label columns, so it survives changes to the pruning.\n",
+ "(y_test == predictions).all(axis=1).mean()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "collapsed": false
+ },
+ "outputs": [],
+ "source": [
+ "sns.distplot(y_test.sum(axis=1), kde=False)\n",
+ "sns.distplot(predictions.sum(axis=1), kde=False)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "collapsed": false
+ },
+ "outputs": [],
+ "source": [
+ "# Per-component label counts: true labels in red, predicted labels in blue.\n",
+ "# The x positions follow the width of the label matrix instead of a fixed count.\n",
+ "n_components = y_test.shape[1]\n",
+ "sns.barplot(range(n_components), y_test.sum(axis=0), color=\"red\")\n",
+ "sns.barplot(range(n_components), predictions.sum(axis=0), color=\"blue\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Serialize the data and the model"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "collapsed": false
+ },
+ "outputs": [],
+ "source": [
+ "def serialize_data_model(vectorizer, classifier, features, targets, transformer=None):\n",
+ "    \"\"\"Pickle the model pieces and the data into the current working directory.\n",
+ "\n",
+ "    Every file name is prefixed with the current Unix time in whole seconds:\n",
+ "    <ts>-vectorizer.pkl, <ts>-classifier.pkl, <ts>.pkl (a dict with the keys\n",
+ "    \"features\" and \"targets\") and, if a transformer is given,\n",
+ "    <ts>-transformer.pkl.\n",
+ "\n",
+ "    Args:\n",
+ "        vectorizer: object pickled to the vectorizer file.\n",
+ "        classifier: object pickled to the classifier file.\n",
+ "        features: stored under \"features\" in the data file.\n",
+ "        targets: stored under \"targets\" in the data file.\n",
+ "        transformer: optional object pickled to the transformer file; skipped\n",
+ "            when None.\n",
+ "    \"\"\"\n",
+ "    current_time = int(time.time())\n",
+ "\n",
+ "    def _dump(obj, filename):\n",
+ "        # `with` closes (and so flushes) the file even if pickling raises.\n",
+ "        with open(filename, \"wb\") as handle:\n",
+ "            pickle.dump(obj, handle)\n",
+ "\n",
+ "    _dump(vectorizer, \"{}-vectorizer.pkl\".format(current_time))\n",
+ "    _dump(classifier, \"{}-classifier.pkl\".format(current_time))\n",
+ "\n",
+ "    training = {\"features\": features, \"targets\": targets}\n",
+ "    _dump(training, \"{}.pkl\".format(current_time))\n",
+ "\n",
+ "    # Test against None: an object's truthiness does not say whether it was supplied.\n",
+ "    if transformer is not None:\n",
+ "        _dump(transformer, \"{}-transformer.pkl\".format(current_time))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "collapsed": false
+ },
+ "outputs": [],
+ "source": [
+ "serialize_data_model(bigram_vectorizer, clf, tfidf, bins, tfidf_transformer)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "collapsed": true
+ },
+ "outputs": [],
+ "source": []
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 2",
+ "language": "python",
+ "name": "python2"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 2
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython2",
+ "version": "2.7.6"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 0
+}