blob: 326c49ccc278b093bcaabb248fb81c46a682d20c [file] [log] [blame]
Copybara854996b2021-09-07 19:36:02 +00001{
2 "cells": [
3 {
4 "cell_type": "code",
5 "execution_count": null,
6 "metadata": {
7 "collapsed": false
8 },
9 "outputs": [],
10 "source": [
11 "%pylab inline"
12 ]
13 },
14 {
15 "cell_type": "code",
16 "execution_count": null,
17 "metadata": {
18 "collapsed": true
19 },
20 "outputs": [],
21 "source": [
22 "from __future__ import print_function\n",
23 "from __future__ import division"
24 ]
25 },
26 {
27 "cell_type": "code",
28 "execution_count": null,
29 "metadata": {
30 "collapsed": false
31 },
32 "outputs": [],
33 "source": [
34 "import pandas as pd\n",
35 "import seaborn as sns\n",
36 "import pickle\n",
37 "import unicodedata\n",
38 "import time\n",
39 "import sklearn\n",
40 "from sklearn.preprocessing import MultiLabelBinarizer\n",
41 "from sklearn.feature_extraction.text import CountVectorizer, TfidfTransformer\n",
42 "from sklearn.svm import LinearSVC\n",
43 "from sklearn.cross_validation import train_test_split\n",
44 "from sklearn.multiclass import OneVsRestClassifier"
45 ]
46 },
47 {
48 "cell_type": "code",
49 "execution_count": null,
50 "metadata": {
51 "collapsed": false
52 },
53 "outputs": [],
54 "source": [
55 "issues = pickle.load(open(\"subset_issue.pkl\"))\n",
56 "comment_text = pickle.load(open(\"comment_text.pkl\"))"
57 ]
58 },
59 {
60 "cell_type": "markdown",
61 "metadata": {},
62 "source": [
63 "table for removing punctuation from text."
64 ]
65 },
66 {
67 "cell_type": "code",
68 "execution_count": null,
69 "metadata": {
70 "collapsed": true
71 },
72 "outputs": [],
73 "source": [
74 "table = dict.fromkeys(i for i in xrange(sys.maxunicode)\n",
75 " if unicodedata.category(unichr(i)).startswith('P'))"
76 ]
77 },
78 {
79 "cell_type": "markdown",
80 "metadata": {},
81 "source": [
82 "### Clean The text"
83 ]
84 },
85 {
86 "cell_type": "code",
87 "execution_count": null,
88 "metadata": {
89 "collapsed": false
90 },
91 "outputs": [],
92 "source": [
93 "def get_text_components_per_issue(issues):\n",
94 " text_per_issue = []\n",
95 " components_per_issue = []\n",
96 "\n",
97 " for index, row in issues.iterrows():\n",
98 " issue_text = \"\"\n",
99 " for comment_id in row[\"comments\"]:\n",
100 " text = comment_text[comment_id].strip()\n",
101 " # Remove punctuation\n",
102 " text = text.translate(table)\n",
103 " issue_text += text + \" \"\n",
104 " text_per_issue.append(issue_text.strip())\n",
105 "\n",
106 " components_per_issue.append(set(row[\"components\"]))\n",
107 " \n",
108 " return text_per_issue, components_per_issue\n",
109 " "
110 ]
111 },
112 {
113 "cell_type": "code",
114 "execution_count": null,
115 "metadata": {
116 "collapsed": false
117 },
118 "outputs": [],
119 "source": [
120 "text_per_issue, components_per_issue = get_text_components_per_issue(issues)"
121 ]
122 },
123 {
124 "cell_type": "markdown",
125 "metadata": {},
126 "source": [
127 "### Filter out components that are used infrequently(not enough singal) or too frequently (signal not meaningful)"
128 ]
129 },
130 {
131 "cell_type": "code",
132 "execution_count": null,
133 "metadata": {
134 "collapsed": false
135 },
136 "outputs": [],
137 "source": [
138 "def prune_and_bin_components(components_per_issue, prune_low=0.005, prune_high=0.25):\n",
139 " mlb = MultiLabelBinarizer()\n",
140 " bins = mlb.fit_transform(components_per_issue)\n",
141 " exclude_comp_ids = set(mlb.classes_[~(((bins.sum(axis=0) / bins.sum()) > prune_low) & \n",
142 " ((bins.sum(axis=0) / bins.sum()) < prune_high))])\n",
143 " \n",
144 " comps_per_issue_exclude = []\n",
145 " for comp_set in components_per_issue:\n",
146 " comps = comp_set - exclude_comp_ids\n",
147 " comps_per_issue_exclude.append(comps)\n",
148 " \n",
149 " mlb = MultiLabelBinarizer()\n",
150 " bins = mlb.fit_transform(comps_per_issue_exclude)\n",
151 " return bins, mlb"
152 ]
153 },
154 {
155 "cell_type": "code",
156 "execution_count": null,
157 "metadata": {
158 "collapsed": false
159 },
160 "outputs": [],
161 "source": [
162 "bins, mlb = prune_and_bin_components(components_per_issue)"
163 ]
164 },
165 {
166 "cell_type": "markdown",
167 "metadata": {},
168 "source": [
169 "### Tokenize the text and perform tfidf transformations"
170 ]
171 },
172 {
173 "cell_type": "code",
174 "execution_count": null,
175 "metadata": {
176 "collapsed": false
177 },
178 "outputs": [],
179 "source": [
180 "bigram_vectorizer = CountVectorizer(ngram_range=(1, 2),\n",
181 " token_pattern=r'\\b\\w+\\b',\n",
182 " min_df=5,\n",
183 " max_df=0.5,\n",
184 " stop_words='english')\n",
185 "\n",
186 "tfidf_transformer = TfidfTransformer()"
187 ]
188 },
189 {
190 "cell_type": "code",
191 "execution_count": null,
192 "metadata": {
193 "collapsed": false
194 },
195 "outputs": [],
196 "source": [
197 "counts = bigram_vectorizer.fit_transform(text_per_issue)\n",
198 "tfidf = tfidf_transformer.fit_transform(counts)"
199 ]
200 },
201 {
202 "cell_type": "code",
203 "execution_count": null,
204 "metadata": {
205 "collapsed": true
206 },
207 "outputs": [],
208 "source": [
209 "X_train, X_test, y_train, y_test = train_test_split(tfidf, bins, train_size=0.8, random_state=42)"
210 ]
211 },
212 {
213 "cell_type": "markdown",
214 "metadata": {},
215 "source": [
216 "### Train a very simple linear model"
217 ]
218 },
219 {
220 "cell_type": "code",
221 "execution_count": null,
222 "metadata": {
223 "collapsed": false
224 },
225 "outputs": [],
226 "source": [
227 "clf = OneVsRestClassifier(LinearSVC(C=1.0))\n",
228 "clf.fit(X_train, y_train)"
229 ]
230 },
231 {
232 "cell_type": "markdown",
233 "metadata": {},
234 "source": [
235 "### Predict and analyze the results"
236 ]
237 },
238 {
239 "cell_type": "code",
240 "execution_count": null,
241 "metadata": {
242 "collapsed": false
243 },
244 "outputs": [],
245 "source": [
246 "predictions = clf.predict(X_test)"
247 ]
248 },
249 {
250 "cell_type": "code",
251 "execution_count": null,
252 "metadata": {
253 "collapsed": false
254 },
255 "outputs": [],
256 "source": [
257 "(y_test == predictions).sum() / (y_test.shape[0] * y_test.shape[1])"
258 ]
259 },
260 {
261 "cell_type": "code",
262 "execution_count": null,
263 "metadata": {
264 "collapsed": false
265 },
266 "outputs": [],
267 "source": [
268 "np.sum((y_test == predictions).sum(axis=1) == 44) / y_test.shape[0]"
269 ]
270 },
271 {
272 "cell_type": "code",
273 "execution_count": null,
274 "metadata": {
275 "collapsed": false
276 },
277 "outputs": [],
278 "source": [
279 "sns.distplot(y_test.sum(axis=1), kde=False)\n",
280 "sns.distplot(predictions.sum(axis=1), kde=False)"
281 ]
282 },
283 {
284 "cell_type": "code",
285 "execution_count": null,
286 "metadata": {
287 "collapsed": false
288 },
289 "outputs": [],
290 "source": [
291 "sns.barplot(range(44), y_test.sum(axis=0), color=\"red\")\n",
292 "sns.barplot(range(44), predictions.sum(axis=0), color=\"blue\")"
293 ]
294 },
295 {
296 "cell_type": "markdown",
297 "metadata": {},
298 "source": [
299 "### Serialize the data and the model"
300 ]
301 },
302 {
303 "cell_type": "code",
304 "execution_count": null,
305 "metadata": {
306 "collapsed": false
307 },
308 "outputs": [],
309 "source": [
310 "def serialize_data_model(vectorizer, classifier, features, targets, transformer=None):\n",
311 " current_time = int(time.time())\n",
312 " pickle.dump(vectorizer, open(\"{}-vectorizer.pkl\".format(current_time), \"wb\"))\n",
313 " pickle.dump(classifier, open(\"{}-classifier.pkl\".format(current_time), \"wb\"))\n",
314 " \n",
315 " training = {\"features\": features, \"targets\": targets}\n",
316 " pickle.dump(training, open(\"{}.pkl\".format(current_time), \"wb\"))\n",
317 " \n",
318 " if transformer:\n",
319 " pickle.dump(transformer, open(\"{}-transformer.pkl\".format(current_time), \"wb\"))\n",
320 " "
321 ]
322 },
323 {
324 "cell_type": "code",
325 "execution_count": null,
326 "metadata": {
327 "collapsed": false
328 },
329 "outputs": [],
330 "source": [
331 "serialize_data_model(bigram_vectorizer, clf, tfidf, bins, tfidf_transformer)"
332 ]
333 },
334 {
335 "cell_type": "code",
336 "execution_count": null,
337 "metadata": {
338 "collapsed": true
339 },
340 "outputs": [],
341 "source": []
342 }
343 ],
344 "metadata": {
345 "kernelspec": {
346 "display_name": "Python 2",
347 "language": "python",
348 "name": "python2"
349 },
350 "language_info": {
351 "codemirror_mode": {
352 "name": "ipython",
353 "version": 2
354 },
355 "file_extension": ".py",
356 "mimetype": "text/x-python",
357 "name": "python",
358 "nbconvert_exporter": "python",
359 "pygments_lexer": "ipython2",
360 "version": "2.7.6"
361 }
362 },
363 "nbformat": 4,
364 "nbformat_minor": 0
365}