
Modules

main(verbose, covid, inp, out, csv, num, rec, unstructured, filters, codedict, topics, assign, cat, summary, sentiment, sentence, nlp, nnet, cls, knn, kmeans, cart, pca, regression, lstm, ml, visualize, ignore, include, outcome, source, sources, print_args, clear)

CRISP-T: Cross Industry Standard Process for Triangulation.

A comprehensive framework for analyzing textual and numerical data using advanced NLP, machine learning, and statistical techniques.
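The example below is a minimal sketch of driving the command in-process with click's test runner; the folder path is hypothetical and is assumed to contain a corpus.json saved by an earlier run, and the flags simply request topic modeling and sentiment reports.

# Minimal in-process invocation sketch (paths are illustrative):
# "output/my_corpus" is assumed to hold a corpus.json written by a previous run.
from click.testing import CliRunner

from crisp_t.cli import main

runner = CliRunner()
result = runner.invoke(
    main,
    ["--inp", "output/my_corpus", "--topics", "--sentiment", "--num", "5", "--rec", "10"],
)
print(result.output)                  # analysis reports echoed by the command
print("exit code:", result.exit_code)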

Source code in src/crisp_t/cli.py
@click.command()
@click.option("--verbose", "-v", is_flag=True, help="Print verbose messages.")
@click.option(
    "--covid", "-cf", default="", help="Download COVID narratives from the website"
)
@click.option("--inp", "-i", help="Load corpus from a folder containing corpus.json")
@click.option("--out", "-o", help="Write corpus to a folder as corpus.json")
@click.option("--csv", default="", help="CSV file name")
@click.option(
    "--num", "-n", default=3, help="N (clusters/epochs, etc, depending on context)"
)
@click.option("--rec", "-r", default=3, help="Record or top_n (based on context)")
@click.option(
    "--unstructured",
    "-t",
    multiple=True,
    help="Csv columns with text data that needs to be treated as text. (Ex. Free text comments)",
)
@click.option(
    "--filters",
    "-f",
    multiple=True,
    help="Filters to apply as key=value (can be used multiple times)",
)
@click.option("--codedict", is_flag=True, help="Generate coding dictionary")
@click.option("--topics", is_flag=True, help="Generate topic model")
@click.option("--assign", is_flag=True, help="Assign documents to topics")
@click.option(
    "--cat", is_flag=True, help="List categories of entire corpus or individual docs"
)
@click.option(
    "--summary",
    is_flag=True,
    help="Generate summary for entire corpus or individual docs",
)
@click.option(
    "--sentiment",
    is_flag=True,
    help="Generate sentiment score for entire corpus or individual docs",
)
@click.option(
    "--sentence",
    is_flag=True,
    default=False,
    help="Generate sentence-level scores when applicable",
)
@click.option("--nlp", is_flag=True, help="Generate all NLP reports")
@click.option("--ml", is_flag=True, help="Generate all ML reports")
@click.option("--nnet", is_flag=True, help="Display accuracy of a neural network model")
@click.option(
    "--cls",
    is_flag=True,
    help="Display confusion matrix from classifiers (SVM, Decision Tree)",
)
@click.option("--knn", is_flag=True, help="Display nearest neighbours")
@click.option("--kmeans", is_flag=True, help="Display KMeans clusters")
@click.option("--cart", is_flag=True, help="Display Association Rules")
@click.option("--pca", is_flag=True, help="Display PCA")
@click.option(
    "--regression", is_flag=True, help="Display linear or logistic regression results"
)
@click.option("--lstm", is_flag=True, help="Train LSTM model on text data to predict outcome variable")
@click.option("--visualize", is_flag=True, help="Visualize words, topics or wordcloud")
@click.option(
    "--ignore",
    default="",
    help="Comma separated ignore words or columns depending on context",
)
@click.option(
    "--include", default="", help="Comma separated columns to include from csv"
)
@click.option("--outcome", default="", help="Outcome variable for ML tasks")
@click.option("--source", "-s", help="Source URL or directory path to read data from")
@click.option("--print", "-p", "print_args", multiple=True, help="Display corpus information. Usage: --print documents --print 10, or quoted: --print 'documents 10'")
@click.option(
    "--sources",
    multiple=True,
    help="Multiple sources (URLs or directories) to read data from; can be used multiple times",
)
@click.option("--clear", is_flag=True, help="Clear cache before running analysis")
def main(
    verbose,
    covid,
    inp,
    out,
    csv,
    num,
    rec,
    unstructured,
    filters,
    codedict,
    topics,
    assign,
    cat,
    summary,
    sentiment,
    sentence,
    nlp,
    nnet,
    cls,
    knn,
    kmeans,
    cart,
    pca,
    regression,
    lstm,
    ml,
    visualize,
    ignore,
    include,
    outcome,
    source,
    sources,
    print_args,
    clear,
):
    """CRISP-T: Cross Industry Standard Process for Triangulation.

    A comprehensive framework for analyzing textual and numerical data using
    advanced NLP, machine learning, and statistical techniques.
    """

    if verbose:
        logging.getLogger().setLevel(logging.DEBUG)
        click.echo("Verbose mode enabled")

    click.echo("_________________________________________")
    click.echo("CRISP-T: Qualitative Research Analysis Framework")
    click.echo(f"Version: {__version__}")
    click.echo("_________________________________________")

    # Initialize components
    read_data = ReadData()
    corpus = None
    text_analyzer = None
    csv_analyzer = None
    ml_analyzer = None

    if clear:
        _clear_cache()

    try:
        # Handle COVID data download
        if covid:
            if not source:
                raise click.ClickException(
                    "--source (output folder) is required when using --covid."
                )
            click.echo(f"Downloading COVID narratives from: {covid} to {source}")
            try:
                from .utils import QRUtils

                QRUtils.read_covid_narratives(source, covid)
                click.echo(f"✓ COVID narratives downloaded to {source}")
            except Exception as e:
                raise click.ClickException(f"COVID download failed: {e}")

        # Build corpus using helpers (source preferred over inp)
        # if not source or inp, use default folders or env vars
        try:
            text_cols = ",".join(unstructured) if unstructured else ""
            corpus = initialize_corpus(
                source=source,
                inp=inp,
                comma_separated_text_columns=text_cols,
                comma_separated_ignore_words=(ignore if ignore else None),
            )
            # If filters were provided with ':' while using --source, emit guidance message
            if source and filters:
                if any(":" in flt and "=" not in flt for flt in filters):
                    click.echo("Filters are not supported when using --source")
        except click.ClickException:
            raise
        except Exception as e:
            click.echo(f"✗ Error initializing corpus: {e}", err=True)
            logger.error(f"Failed to initialize corpus: {e}")
            return

        # Handle multiple sources (unchanged behavior, but no filters applied here)
        if sources and not corpus:
            loaded_any = False
            for src in sources:
                click.echo(f"Reading data from source: {src}")
                try:
                    read_data.read_source(
                        src, comma_separated_ignore_words=ignore if ignore else None
                    )
                    loaded_any = True
                except Exception as e:
                    logger.error(f"Failed to read source {src}: {e}")
                    raise click.ClickException(str(e))

            if loaded_any:
                corpus = read_data.create_corpus(
                    name="Corpus from multiple sources",
                    description=f"Data loaded from {len(sources)} sources",
                )
                click.echo(
                    f"✓ Successfully loaded {len(corpus.documents)} document(s) from {len(sources)} sources"
                )
                # Filters are not applied for --sources in bulk mode

        # Load csv from corpus.df if available via helper
        if corpus and getattr(corpus, "df", None) is not None:
            try:
                text_cols = ",".join(unstructured) if unstructured else ""
                csv_analyzer = get_csv_analyzer(
                    corpus,
                    comma_separated_unstructured_text_columns=text_cols,
                    comma_separated_ignore_columns=(ignore if ignore else ""),
                    filters=filters,
                )
            except Exception as e:
                click.echo(f"✗ Error preparing CSV analyzer: {e}", err=True)
                logger.error(f"Failed to create CSV analyzer: {e}")
                return

        # Load CSV data (deprecated)
        if csv:
            click.echo(
                "--csv option has been deprecated. Put csv file in --source folder instead."
            )

        # Initialize ML analyzer if available and ML functions are requested
        if (
            ML_AVAILABLE
            and (nnet or cls or knn or kmeans or cart or pca or regression or lstm or ml)
            and csv_analyzer
        ):
            if include:
                csv_analyzer.comma_separated_include_columns(include)
            ml_analyzer = ML(csv=csv_analyzer)  # type: ignore
        else:
            if (nnet or cls or knn or kmeans or cart or pca or regression or lstm or ml) and not ML_AVAILABLE:
                click.echo("Machine learning features require additional dependencies.")
                click.echo("Install with: pip install crisp-t[ml]")
            if (nnet or cls or knn or kmeans or cart or pca or regression or lstm or ml) and not csv_analyzer:
                click.echo(
                    "ML analysis requires CSV data. Use --csv to provide a data file."
                )

        # Initialize Text analyzer and apply filters using helper if we have a corpus
        if corpus and not text_analyzer:
            text_analyzer = get_text_analyzer(corpus, filters=filters)

        # Ensure we have data to work with
        if not corpus and not csv_analyzer:
            click.echo(
                "No input data provided. Use --inp for text files"
            )
            return

        # Text Analysis Operations
        if text_analyzer:
            if nlp or codedict:
                click.echo("\n=== Generating Coding Dictionary ===")
                click.echo(
                    """
                Coding Dictionary Format:
                - CATEGORY: Common verbs representing main actions or themes.
                - PROPERTY: Common nouns associated with each CATEGORY.
                - DIMENSION: Common adjectives, adverbs, or verbs associated with each PROPERTY.

                Hint:   Use --ignore with a comma-separated list of words to exclude common but uninformative words.
                        Use --filters to narrow down documents based on metadata.
                        Use --num to adjust the number of categories displayed.
                        Use --rec to adjust the number of top items displayed per section.
                """
                )
                try:
                    text_analyzer.make_spacy_doc()
                    coding_dict = text_analyzer.print_coding_dictionary(
                        num=num, top_n=rec
                    )
                    if out:
                        _save_output(coding_dict, out, "coding_dictionary")
                except Exception as e:
                    click.echo(f"Error generating coding dictionary: {e}")

            if nlp or topics:
                click.echo("\n=== Topic Modeling ===")
                click.echo(
                    """
                Topic Modeling Output Format:
                Each topic is represented as a list of words with associated weights indicating their importance within the topic.
                Example:
                Topic 0: 0.116*"category" + 0.093*"comparison" + 0.070*"incident" + ...
                Hint:   Use --num to adjust the number of topics generated.
                        Use --filters to narrow down documents based on metadata.
                        Use --rec to adjust the number of words displayed per topic.
                """
                )
                try:
                    cluster_analyzer = Cluster(corpus=corpus)
                    cluster_analyzer.build_lda_model(topics=num)
                    topics_result = cluster_analyzer.print_topics(num_words=rec)
                    click.echo(
                        f"Generated {len(topics_result)} topics as above with the weights in brackets."
                    )
                    if out:
                        _save_output(topics_result, out, "topics")
                except Exception as e:
                    click.echo(f"Error generating topics: {e}")

            if nlp or assign:
                click.echo("\n=== Document-Topic Assignments ===")
                click.echo(
                    """
                Document-Topic Assignment Format:
                Each document is assigned to the topic it is most associated with, along with the contribution percentage.
                Hint: --visualize adds a DataFrame to corpus.visualization["assign_topics"] for visualization.
                """
                )
                try:
                    if "cluster_analyzer" not in locals():
                        cluster_analyzer = Cluster(corpus=corpus)
                        cluster_analyzer.build_lda_model(topics=num)
                    assignments = cluster_analyzer.format_topics_sentences(
                        visualize=visualize
                    )
                    document_assignments = cluster_analyzer.print_clusters()
                    click.echo(f"Assigned {len(assignments)} documents to topics")
                    if out:
                        _save_output(assignments, out, "topic_assignments")
                except Exception as e:
                    click.echo(f"Error assigning topics: {e}")

            if nlp or cat:
                click.echo("\n=== Category Analysis ===")
                click.echo(
                    """
                Category Analysis Output Format:
                           A list of common concepts or themes in "bag_of_terms" with corresponding weights.
                Hint:   Use --num to adjust the number of categories displayed.
                        Use --filters to narrow down documents based on metadata.
                """
                )
                try:
                    text_analyzer.make_spacy_doc()
                    categories = text_analyzer.print_categories(num=num)
                    if out:
                        _save_output(categories, out, "categories")
                except Exception as e:
                    click.echo(f"Error generating categories: {e}")

            if nlp or summary:
                click.echo("\n=== Text Summarization ===")
                click.echo(
                    """
                Text Summarization Output Format: A list of important sentences representing the main points of the text.
                Hint:   Use --num to adjust the number of sentences in the summary.
                        Use --filters to narrow down documents based on metadata.
                """
                )
                try:
                    text_analyzer.make_spacy_doc()
                    summary_result = text_analyzer.generate_summary(weight=num)
                    click.echo(summary_result)
                    if out:
                        _save_output(summary_result, out, "summary")
                except Exception as e:
                    click.echo(f"Error generating summary: {e}")

            if nlp or sentiment:
                click.echo("\n=== Sentiment Analysis ===")
                click.echo(
                    """
                Sentiment Analysis Output Format:
                           neg, neu, pos, compound scores.
                Hint:   Use --filters to narrow down documents based on metadata.
                        Use --sentence to get document-level sentiment scores.
                """
                )
                try:
                    sentiment_analyzer = Sentiment(corpus=corpus)  # type: ignore
                    sentiment_results = sentiment_analyzer.get_sentiment(
                        documents=sentence, verbose=verbose
                    )
                    click.echo(sentiment_results)
                    if out:
                        _save_output(sentiment_results, out, "sentiment")
                except Exception as e:
                    click.echo(f"Error generating sentiment analysis: {e}")

        # Machine Learning Operations
        if ml_analyzer and ML_AVAILABLE:
            target_col = outcome

            if kmeans or ml:
                click.echo("\n=== K-Means Clustering ===")
                click.echo(
                    """
                           K-Means clustering removes non-numeric columns.
                           Additionally it removes NaN values.
                           So combining with other ML options may not work as expected.
                Hint:   Use --num to adjust the number of clusters generated.
                """
                )
                csv_analyzer.retain_numeric_columns_only()
                csv_analyzer.drop_na()
                _ml_analyzer = ML(csv=csv_analyzer)
                clusters, members = _ml_analyzer.get_kmeans(
                    number_of_clusters=num, verbose=verbose
                )
                _ml_analyzer.profile(members, number_of_clusters=num)
                if out:
                    _save_output(
                        {"clusters": clusters, "members": members}, out, "kmeans"
                    )

            if (cls or ml) and target_col:
                click.echo("\n=== Classifier Evaluation ===")
                click.echo(
                    """
                           Classifier
                            - SVM: Support Vector Machine classifier with confusion matrix output.
                            - Decision Tree: Decision Tree classifier with feature importance output.
                Hint:   Use --outcome to specify the target variable for classification.
                        Use --rec to adjust the number of top important features displayed.
                        Use --include to specify columns to include in the analysis (comma separated).
                """
                )
                if not target_col:
                    raise click.ClickException(
                        "--outcome is required for classification tasks"
                    )
                click.echo("\n=== SVM ===")
                try:
                    confusion_matrix = ml_analyzer.svm_confusion_matrix(
                        y=target_col, test_size=0.25
                    )
                    click.echo(
                        ml_analyzer.format_confusion_matrix_to_human_readable(
                            confusion_matrix
                        )
                    )
                    if out:
                        _save_output(confusion_matrix, out, "svm_results")
                except Exception as e:
                    click.echo(f"Error performing SVM classification: {e}")
                click.echo("\n=== Decision Tree Classification ===")
                try:
                    cm, importance = ml_analyzer.get_decision_tree_classes(
                        y=target_col, top_n=rec
                    )
                    click.echo("\n=== Feature Importance ===")
                    click.echo(
                        ml_analyzer.format_confusion_matrix_to_human_readable(cm)
                    )
                    if out:
                        _save_output(cm, out, "decision_tree_results")
                except Exception as e:
                    click.echo(f"Error performing Decision Tree classification: {e}")

            if (nnet or ml) and target_col:
                click.echo("\n=== Neural Network Classification Accuracy ===")
                click.echo(
                    """
                            Neural Network classifier with accuracy output.
                Hint:   Use --outcome to specify the target variable for classification.
                        Use --include to specify columns to include in the analysis (comma separated).
                """
                )
                if not target_col:
                    raise click.ClickException(
                        "--outcome is required for neural network tasks"
                    )
                try:
                    predictions = ml_analyzer.get_nnet_predictions(y=target_col)
                    if out:
                        _save_output(predictions, out, "nnet_results")
                except Exception as e:
                    click.echo(f"Error performing Neural Network classification: {e}")

            if (knn or ml) and target_col:
                click.echo("\n=== K-Nearest Neighbors ===")
                click.echo(
                    """
                           K-Nearest Neighbors search results.
                Hint:   Use --outcome to specify the target variable for KNN search.
                        Use --rec to specify the record number to search from (1-based index).
                        Use --num to specify the number of nearest neighbors to retrieve.
                        Use --include to specify columns to include in the analysis (comma separated).
                """
                )
                if not target_col:
                    raise click.ClickException(
                        "--outcome is required for KNN search tasks"
                    )
                if rec < 1:
                    raise click.ClickException(
                        "--rec must be a positive integer (1-based index)"
                    )
                try:
                    knn_results = ml_analyzer.knn_search(y=target_col, n=num, r=rec)
                    if out:
                        _save_output(knn_results, out, "knn_results")
                except Exception as e:
                    click.echo(f"Error performing K-Nearest Neighbors search: {e}")

            if (cart or ml) and target_col:
                click.echo("\n=== Association Rules (CART) ===")
                click.echo(
                    """
                           Association Rules using the Apriori algorithm.
                Hint:   Use --outcome to specify the target variable to remove from features.
                        Use --num to specify the minimum support (between 1 and 99).
                        Use --rec to specify the minimum threshold for the rules (between 1 and 99).
                        Use --include to specify columns to include in the analysis (comma separated).
                """
                )
                if not target_col:
                    raise click.ClickException(
                        "--outcome is required for association rules tasks"
                    )
                if not (1 <= num <= 99):
                    raise click.ClickException(
                        "--num must be between 1 and 99 for min_support"
                    )
                if not (1 <= rec <= 99):
                    raise click.ClickException(
                        "--rec must be between 1 and 99 for min_threshold"
                    )
                _min_support = float(num / 100)
                _min_threshold = float(rec / 100)
                click.echo(
                    f"Using min_support={_min_support:.2f} and min_threshold={_min_threshold:.2f}"
                )
                try:
                    apriori_results = ml_analyzer.get_apriori(
                        y=target_col,
                        min_support=_min_support,
                        min_threshold=_min_threshold,
                    )
                    click.echo(apriori_results)
                    if out:
                        _save_output(apriori_results, out, "association_rules")
                except Exception as e:
                    click.echo(f"Error generating association rules: {e}")

            if (pca or ml) and target_col:
                click.echo("\n=== Principal Component Analysis ===")
                click.echo(
                    """
                           Principal Component Analysis (PCA) results.
                Hint:   Use --outcome to specify the target variable to remove from features.
                        Use --num to specify the number of principal components to generate.
                        Use --include to specify columns to include in the analysis (comma separated).
                """
                )
                try:
                    pca_results = ml_analyzer.get_pca(y=target_col, n=num)
                    if out:
                        _save_output(pca_results, out, "pca_results")
                except Exception as e:
                    click.echo(f"Error performing Principal Component Analysis: {e}")

            if (regression or ml) and target_col:
                click.echo("\n=== Regression Analysis ===")
                click.echo(
                    """
                           Regression Analysis (Linear or Logistic Regression).
                           Automatically detects binary outcomes for logistic regression.
                           Otherwise uses linear regression for continuous outcomes.
                Hint:   Use --outcome to specify the target variable for regression.
                        Use --include to specify columns to include in the analysis (comma separated).
                """
                )
                try:
                    regression_results = ml_analyzer.get_regression(y=target_col)
                    if out:
                        _save_output(regression_results, out, "regression_results")
                except Exception as e:
                    click.echo(f"Error performing regression analysis: {e}")

            if (lstm or ml) and target_col:
                click.echo("\n=== LSTM Text Classification ===")
                click.echo(
                    """
                           LSTM (Long Short-Term Memory) model for text-based prediction.
                           Tests if text documents converge towards predicting the outcome variable.
                           Requires both text documents and an 'id' column to align texts with outcome.
                Hint:   Use --outcome to specify the target variable for LSTM prediction.
                        The outcome should be binary (two classes).
                        Ensure documents have IDs matching the 'id' column in your data.
                """
                )
                if not target_col:
                    raise click.ClickException(
                        "--outcome is required for LSTM prediction tasks"
                    )
                try:
                    lstm_results = ml_analyzer.get_lstm_predictions(y=target_col)
                    if out:
                        _save_output(lstm_results, out, "lstm_results")
                except Exception as e:
                    click.echo(f"Error performing LSTM prediction: {e}")

        elif (nnet or cls or knn or kmeans or cart or pca or regression or lstm or ml) and not ML_AVAILABLE:
            click.echo("Machine learning features require additional dependencies.")
            click.echo("Install with: pip install crisp-t[ml]")

        # Save corpus and csv if output path is specified
        if out and corpus:
            if filters and inp and out and inp == out:
                raise click.ClickException(
                    "--out cannot be the same as --inp when using --filters. Please specify a different output folder to avoid overwriting input data."
                )
            if filters and ((not inp) or (not out)):
                raise click.ClickException(
                    "Both --inp and --out must be specified when using --filters."
                )
            output_path = pathlib.Path(out)
            # Allow both directory and a file path '.../corpus.json'
            if output_path.suffix:
                # Ensure parent exists
                output_path.parent.mkdir(parents=True, exist_ok=True)
                save_base = output_path
            else:
                output_path.mkdir(parents=True, exist_ok=True)
                save_base = output_path / "corpus.json"
            read_data.write_corpus_to_json(str(save_base), corpus=corpus)
            click.echo(f"✓ Corpus and csv saved to {save_base}")

        if print_args and corpus:
            click.echo("\n=== Corpus Details ===")
            # Join the print arguments into a single string
            print_command = " ".join(print_args) if print_args else None
            if print_command:
                click.echo(corpus.pretty_print(show=print_command))

        click.echo("\n=== Analysis Complete ===")

    except click.ClickException:
        # Let Click handle and set non-zero exit code
        raise
    except Exception as e:
        # Convert unexpected exceptions to ClickException for non-zero exit code
        if verbose:
            import traceback

            traceback.print_exc()
        raise click.ClickException(str(e))

Copyright (C) 2025 Bell Eapen

This file is part of crisp-t.

crisp-t is free software: you can redistribute it and/or modify it under the terms of the GNU General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your option) any later version.

crisp-t is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for more details.

You should have received a copy of the GNU General Public License along with crisp-t. If not, see https://www.gnu.org/licenses/.

ReadData
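
Reads text from URLs, folders of .txt/.pdf/.csv files, or a previously saved corpus.json, and assembles the result into a Corpus (a list of documents plus an optional DataFrame). A short usage sketch follows, assuming ./interviews is a local folder of transcripts; the path, ignore words, and corpus name are illustrative only.

from crisp_t.read_data import ReadData

reader = ReadData()
reader.read_source("./interviews", comma_separated_ignore_words="um,uh")
corpus = reader.create_corpus(
    name="Interview corpus",
    description="Transcripts loaded from ./interviews",
)
reader.pretty_print()                    # dump the corpus (minus df/visualization) as JSON
reader.write_corpus_to_json("./output")  # writes corpus.json (and corpus_df.csv when a DataFrame is present)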

Source code in src/crisp_t/read_data.py
class ReadData:

    def __init__(self, corpus: Corpus | None = None, source=None):
        self._corpus = corpus
        self._source = source
        self._documents = []
        self._df = pd.DataFrame()

    @property
    def corpus(self):
        """
        Get the corpus.
        """
        if not self._corpus:
            raise ValueError("No corpus found. Please create a corpus first.")
        self._corpus.documents = self._documents
        self._corpus.df = self._df
        return self._corpus

    @property
    def documents(self):
        """
        Get the documents.
        """
        if not self._documents:
            raise ValueError("No documents found. Please read data first.")
        return self._documents

    @property
    def df(self):
        """
        Get the dataframe.
        """
        if self._df is None:
            raise ValueError("No dataframe found. Please read data first.")
        return self._df

    @corpus.setter
    def corpus(self, value):
        """
        Set the corpus.
        """
        if not isinstance(value, Corpus):
            raise ValueError("Value must be a Corpus object.")
        self._corpus = value

    @documents.setter
    def documents(self, value):
        """
        Set the documents.
        """
        if not isinstance(value, list):
            raise ValueError("Value must be a list of Document objects.")
        for document in value:
            if not isinstance(document, Document):
                raise ValueError("Value must be a list of Document objects.")
        self._documents = value

    @df.setter
    def df(self, value):
        """
        Set the dataframe.
        """
        if not isinstance(value, pd.DataFrame):
            raise ValueError("Value must be a pandas DataFrame.")
        self._df = value

    def pretty_print(self):
        """
        Pretty print the corpus.
        """
        if not self._corpus:
            self.create_corpus()
        if self._corpus:
            print(
                self._corpus.model_dump_json(indent=4, exclude={"df", "visualization"})
            )
            logger.info(
                "Corpus: %s",
                self._corpus.model_dump_json(indent=4, exclude={"df", "visualization"}),
            )
        else:
            logger.error("No corpus available to pretty print.")

    # TODO: Enforce only one corpus (Singleton pattern)
    def create_corpus(self, name=None, description=None):
        """
        Create a corpus from the documents and dataframe.
        """
        if not self._documents:
            raise ValueError("No documents found. Please read data first.")
        if self._corpus:
            self._corpus.documents = self._documents
            self._corpus.df = self._df
        else:
            timestamp = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
            corpus_id = f"corpus_{timestamp}"
            self._corpus = Corpus(
                documents=self._documents,
                df=self._df,
                visualization={},
                metadata={},
                id=corpus_id,
                score=0.0,
                name=name,
                description=description,
            )
        return self._corpus

    def get_documents_from_corpus(self):
        """
        Get the documents from the corpus.
        """
        if not self._corpus:
            raise ValueError("No corpus found. Please create a corpus first.")
        return self._corpus.documents

    def get_document_by_id(self, doc_id):
        """
        Get a document from the corpus by its ID. Uses parallel search for large corpora.
        """
        if not self._corpus:
            raise ValueError("No corpus found. Please create a corpus first.")
        documents = self._corpus.documents
        if len(documents) < 10:
            for document in tqdm(documents, desc="Searching documents", disable=True):
                if document.id == doc_id:
                    return document
        else:
            n_cores = multiprocessing.cpu_count()
            with ThreadPoolExecutor() as executor:
                futures = {
                    executor.submit(lambda doc: doc.id == doc_id, document): i
                    for i, document in enumerate(documents)
                }
                with tqdm(
                    total=len(futures),
                    desc=f"Searching documents (parallel, {n_cores} cores)",
                ) as pbar:
                    for future in as_completed(futures):
                        i = futures[future]
                        found = future.result()
                        pbar.update(1)
                        if found:
                            return documents[i]
        raise ValueError("Document not found: %s" % doc_id)

    def write_corpus_to_json(self, file_path="", corpus=None):
        """
        Write the corpus to a json file.

        Accepts either a directory path or an explicit file path ending with
        'corpus.json'. In both cases, a sibling 'corpus_df.csv' will be written
        next to the json if a DataFrame is available.
        """
        from pathlib import Path

        path = Path(file_path)
        # Determine targets
        if path.suffix:  # treat as explicit file path
            file_name = path
            df_name = path.with_name("corpus_df.csv")
        else:
            file_name = path / "corpus.json"
            df_name = path / "corpus_df.csv"

        corp = corpus if corpus is not None else self._corpus
        if not corp:
            raise ValueError("No corpus found. Please create a corpus first.")
        file_name.parent.mkdir(parents=True, exist_ok=True)
        with open(file_name, "w") as f:
            json.dump(corp.model_dump(exclude={"df", "visualization"}), f, indent=4)
        if corp.df is not None and isinstance(corp.df, pd.DataFrame):
            if not corp.df.empty:
                corp.df.to_csv(df_name, index=False)
        logger.info("Corpus written to %s", file_name)

    # @lru_cache(maxsize=3)
    def read_corpus_from_json(self, file_path="", comma_separated_ignore_words=""):
        """
        Read the corpus from a json file. Parallelizes ignore word removal for large corpora.
        """
        from pathlib import Path

        file_path = Path(file_path)
        file_name = file_path / "corpus.json"
        df_name = file_path / "corpus_df.csv"
        if self._source:
            file_name = Path(self._source) / file_name
        if not file_name.exists():
            raise ValueError(f"File not found: {file_name}")
        with open(file_name, "r") as f:
            data = json.load(f)
            self._corpus = Corpus.model_validate(data)
            logger.info(f"Corpus read from {file_name}")
        if df_name.exists():
            self._corpus.df = pd.read_csv(df_name)
        else:
            self._corpus.df = None
        # Remove ignore words from self._corpus.documents text
        documents = self._corpus.documents

        # Pre-compile regex patterns once for efficiency instead of inside loops
        compiled_patterns = []
        if comma_separated_ignore_words:
            for word in comma_separated_ignore_words.split(","):
                pattern = re.compile(r"\b" + word.strip() + r"\b", flags=re.IGNORECASE)
                compiled_patterns.append(pattern)

        if len(documents) < 10:
            processed_docs = []
            for document in tqdm(documents, desc="Processing documents", disable=True):
                for pattern in compiled_patterns:
                    document.text = pattern.sub("", document.text)
                processed_docs.append(document)
        else:

            def process_doc(document):
                for pattern in compiled_patterns:
                    document.text = pattern.sub("", document.text)
                return document

            processed_docs = []
            n_cores = multiprocessing.cpu_count()
            with ThreadPoolExecutor() as executor:
                futures = {
                    executor.submit(process_doc, document): document
                    for document in documents
                }
                with tqdm(
                    total=len(futures),
                    desc=f"Processing documents (parallel, {n_cores} cores)",
                ) as pbar:
                    for future in as_completed(futures):
                        processed_docs.append(future.result())
                        pbar.update(1)
        self._corpus.documents = processed_docs
        return self._corpus

    # @lru_cache(maxsize=3)
    def read_csv_to_corpus(
        self,
        file_name,
        comma_separated_ignore_words=None,
        comma_separated_text_columns="",
        id_column="",
    ):
        """
        Read the corpus from a csv file. Parallelizes document creation for large CSVs.
        """
        from pathlib import Path

        file_name = Path(file_name)
        if not file_name.exists():
            raise ValueError(f"File not found: {file_name}")
        df = pd.read_csv(file_name)
        original_df = df.copy()
        if comma_separated_text_columns:
            text_columns = comma_separated_text_columns.split(",")
        else:
            text_columns = []
        # remove text columns from the dataframe
        for column in text_columns:
            if column in df.columns:
                df.drop(column, axis=1, inplace=True)
        # Set self._df to the numeric part after dropping text columns
        self._df = df.copy()
        rows = list(original_df.iterrows())

        # Pre-compile regex patterns once for efficiency instead of inside loops
        compiled_patterns = []
        if comma_separated_ignore_words:
            for word in comma_separated_ignore_words.split(","):
                pattern = re.compile(r"\b" + word.strip() + r"\b", flags=re.IGNORECASE)
                compiled_patterns.append(pattern)

        def create_document(args):
            index, row = args
            # Use list and join for efficient string concatenation, handle None values
            text_parts = [
                str(row[column])
                if row[column] is not None
                and not (isinstance(row[column], float) and row[column] != row[column])
                else ""
                for column in text_columns
            ]
            read_from_file = " ".join(text_parts)
            # Apply pre-compiled patterns
            for pattern in compiled_patterns:
                read_from_file = pattern.sub("", read_from_file)
            _document = Document(
                text=read_from_file,
                metadata={
                    "source": str(file_name),
                    "file_name": str(file_name),
                    "row": index,
                    "id": (
                        row[id_column]
                        if (id_column != "" and id_column in original_df.columns)
                        else index
                    ),
                },
                id=str(index),
                score=0.0,
                name="",
                description="",
            )
            return read_from_file, _document

        if len(rows) < 10:
            results = [
                create_document(args)
                for args in tqdm(rows, desc="Reading CSV rows", disable=True)
            ]
        else:

            results = []
            # import multiprocessing

            n_cores = multiprocessing.cpu_count()
            with ThreadPoolExecutor() as executor:
                futures = {
                    executor.submit(create_document, args): args for args in rows
                }
                with tqdm(
                    total=len(futures),
                    desc=f"Reading CSV rows (parallel, {n_cores} cores)",
                ) as pbar:
                    for future in as_completed(futures):
                        results.append(future.result())
                        pbar.update(1)

        if len(results) < 10:
            for read_from_file, _document in tqdm(
                results, desc="Finalizing corpus", disable=True
            ):
                self._documents.append(_document)
        else:

            # import multiprocessing

            n_cores = multiprocessing.cpu_count()
            with tqdm(
                results,
                total=len(results),
                desc=f"Finalizing corpus (parallel, {n_cores} cores)",
            ) as pbar:
                for read_from_file, _document in pbar:
                    self._documents.append(_document)
        logger.info(f"Corpus read from {file_name}")
        self.create_corpus()
        return self._corpus

    def read_source(
        self, source, comma_separated_ignore_words=None, comma_separated_text_columns=""
    ):
        _CSV_EXISTS = False

        # Pre-compile regex patterns once for efficiency instead of inside loops
        compiled_patterns = []
        if comma_separated_ignore_words:
            for word in comma_separated_ignore_words.split(","):
                pattern = re.compile(r"\b" + word.strip() + r"\b", flags=re.IGNORECASE)
                compiled_patterns.append(pattern)

        def apply_ignore_patterns(text):
            """Apply pre-compiled ignore patterns to text."""
            for pattern in compiled_patterns:
                text = pattern.sub("", text)
            return text

        # if source is a url
        if source.startswith("http://") or source.startswith("https://"):
            response = requests.get(source)
            if response.status_code == 200:
                read_from_file = response.text
                read_from_file = apply_ignore_patterns(read_from_file)
                # self._content removed
                _document = Document(
                    text=read_from_file,
                    metadata={"source": source},
                    id=source,
                    score=0.0,
                    name="",
                    description="",
                )
                self._documents.append(_document)
        elif os.path.exists(source):
            source_path = Path(source)
            self._source = source
            logger.info(f"Reading data from folder: {source}")
            file_list = os.listdir(source)
            for file_name in tqdm(
                file_list, desc="Reading files", disable=len(file_list) < 10
            ):
                file_path = source_path / file_name
                if file_name.endswith(".txt"):
                    with open(file_path, "r") as f:
                        read_from_file = f.read()
                        read_from_file = apply_ignore_patterns(read_from_file)
                        # self._content removed
                        _document = Document(
                            text=read_from_file,
                            metadata={
                                "source": str(file_path),
                                "file_name": file_name,
                            },
                            id=file_name,
                            score=0.0,
                            name="",
                            description="",
                        )
                        self._documents.append(_document)
                if file_name.endswith(".pdf"):
                    with open(file_path, "rb") as f:
                        reader = PdfReader(f)
                        # Use list and join for efficient string concatenation
                        page_texts = []
                        for page in tqdm(
                            reader.pages,
                            desc=f"Reading PDF {file_name}",
                            leave=False,
                            disable=len(reader.pages) < 10,
                        ):
                            page_texts.append(page.extract_text())
                        read_from_file = "".join(page_texts)
                        read_from_file = apply_ignore_patterns(read_from_file)
                        # self._content removed
                        _document = Document(
                            text=read_from_file,
                            metadata={
                                "source": str(file_path),
                                "file_name": file_name,
                            },
                            id=file_name,
                            score=0.0,
                            name="",
                            description="",
                        )
                        self._documents.append(_document)
                if file_name.endswith(".csv") and comma_separated_text_columns == "":
                    logger.info(f"Reading CSV file: {file_path}")
                    self._df = Csv().read_csv(file_path)
                    logger.info(f"CSV file read with shape: {self._df.shape}")
                    _CSV_EXISTS = True
                if file_name.endswith(".csv") and comma_separated_text_columns != "":
                    logger.info(f"Reading CSV file to corpus: {file_path}")
                    self.read_csv_to_corpus(
                        file_path,
                        comma_separated_ignore_words,
                        comma_separated_text_columns,
                    )
                    logger.info(
                        f"CSV file read to corpus with documents: {len(self._documents)}"
                    )
                    _CSV_EXISTS = True
            if not _CSV_EXISTS:
                # create a simple csv with columns: id, number, text
                # and fill it with random data
                _csv = """
id,number,response
1,100,Sample text one
2,200,Sample text two
3,300,Sample text three
4,400,Sample text four
"""
                # write the csv to a temp file
                with tempfile.NamedTemporaryFile(
                    mode="w+", delete=False, suffix=".csv"
                ) as temp_csv:
                    temp_csv.write(_csv)
                    temp_csv_path = temp_csv.name
                logger.info(f"No CSV found. Created temp CSV file: {temp_csv_path}")
                self._df = Csv().read_csv(temp_csv_path)
                logger.info(f"CSV file read with shape: {self._df.shape}")
                # remove the temp file
                os.remove(temp_csv_path)

        else:
            raise ValueError(f"Source not found: {source}")

    def corpus_as_dataframe(self):
        """
        Convert the corpus to a pandas dataframe. Parallelizes for large corpora.
        """
        if not self._corpus:
            raise ValueError("No corpus found. Please create a corpus first.")
        documents = self._corpus.documents
        if len(documents) < 10:
            data = [
                document.model_dump()
                for document in tqdm(
                    documents, desc="Converting to dataframe", disable=True
                )
            ]
        else:
            data = []

            def dump_doc(document):
                return document.model_dump()

            n_cores = multiprocessing.cpu_count()
            with ThreadPoolExecutor() as executor:
                futures = {
                    executor.submit(dump_doc, document): document
                    for document in documents
                }
                with tqdm(
                    total=len(futures),
                    desc=f"Converting to dataframe (parallel, {n_cores} cores)",
                ) as pbar:
                    for future in as_completed(futures):
                        data.append(future.result())
                        pbar.update(1)
        df = pd.DataFrame(data)
        return df

corpus property writable

Get the corpus.

df property writable

Get the dataframe.

documents property writable

Get the documents.

corpus_as_dataframe()

Convert the corpus to a pandas dataframe. Parallelizes for large corpora.
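Illustrative usage, continuing the ReadData sketch near the top of this class (the column names assume the Document fields used elsewhere in this module):

df = reader.corpus_as_dataframe()
print(df[["id", "text"]].head())   # one row per Document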

Source code in src/crisp_t/read_data.py
def corpus_as_dataframe(self):
    """
    Convert the corpus to a pandas dataframe. Parallelizes for large corpora.
    """
    if not self._corpus:
        raise ValueError("No corpus found. Please create a corpus first.")
    documents = self._corpus.documents
    if len(documents) < 10:
        data = [
            document.model_dump()
            for document in tqdm(
                documents, desc="Converting to dataframe", disable=True
            )
        ]
    else:
        data = []

        def dump_doc(document):
            return document.model_dump()

        n_cores = multiprocessing.cpu_count()
        with ThreadPoolExecutor() as executor:
            futures = {
                executor.submit(dump_doc, document): document
                for document in documents
            }
            with tqdm(
                total=len(futures),
                desc=f"Converting to dataframe (parallel, {n_cores} cores)",
            ) as pbar:
                for future in as_completed(futures):
                    data.append(future.result())
                    pbar.update(1)
    df = pd.DataFrame(data)
    return df

create_corpus(name=None, description=None)

Create a corpus from the documents and dataframe.
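Typically called after one of the read_* methods has populated the internal document list. A hedged sketch with a hypothetical folder path:

reader = ReadData()
reader.read_source("./interviews")
corpus = reader.create_corpus(name="Interviews", description="Loaded from ./interviews")
print(corpus.id)                 # timestamp-based id, e.g. corpus_20250101120000
print(len(corpus.documents))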

Source code in src/crisp_t/read_data.py
def create_corpus(self, name=None, description=None):
    """
    Create a corpus from the documents and dataframe.
    """
    if not self._documents:
        raise ValueError("No documents found. Please read data first.")
    if self._corpus:
        self._corpus.documents = self._documents
        self._corpus.df = self._df
    else:
        timestamp = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
        corpus_id = f"corpus_{timestamp}"
        self._corpus = Corpus(
            documents=self._documents,
            df=self._df,
            visualization={},
            metadata={},
            id=corpus_id,
            score=0.0,
            name=name,
            description=description,
        )
    return self._corpus

get_document_by_id(doc_id)

Get a document from the corpus by its ID. Uses parallel search for large corpora.
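Illustrative lookup; the id is hypothetical and assumes the documents were read from .txt files, whose ids are their file names (see read_source), and that a corpus has already been created:

doc = reader.get_document_by_id("interview_01.txt")
print(doc.metadata["source"], len(doc.text))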

Source code in src/crisp_t/read_data.py
def get_document_by_id(self, doc_id):
    """
    Get a document from the corpus by its ID. Uses parallel search for large corpora.
    """
    if not self._corpus:
        raise ValueError("No corpus found. Please create a corpus first.")
    documents = self._corpus.documents
    if len(documents) < 10:
        for document in tqdm(documents, desc="Searching documents", disable=True):
            if document.id == doc_id:
                return document
    else:
        n_cores = multiprocessing.cpu_count()
        with ThreadPoolExecutor() as executor:
            futures = {
                executor.submit(lambda doc: doc.id == doc_id, document): i
                for i, document in enumerate(documents)
            }
            with tqdm(
                total=len(futures),
                desc=f"Searching documents (parallel, {n_cores} cores)",
            ) as pbar:
                for future in as_completed(futures):
                    i = futures[future]
                    found = future.result()
                    pbar.update(1)
                    if found:
                        return documents[i]
    raise ValueError("Document not found: %s" % doc_id)
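
A small lookup sketch (the folder and document ID are hypothetical). Documents created from a CSV use the stringified row index as their ID, and a missing ID raises ValueError rather than returning None.

from crisp_t.read_data import ReadData

rd = ReadData()
rd.read_corpus_from_json("out")           # hypothetical folder with corpus.json
try:
    doc = rd.get_document_by_id("0")      # IDs are strings
    print(doc.id, doc.text[:80])
except ValueError as err:
    print(err)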

get_documents_from_corpus()

Get the documents from the corpus.

Source code in src/crisp_t/read_data.py
def get_documents_from_corpus(self):
    """
    Get the documents from the corpus.
    """
    if not self._corpus:
        raise ValueError("No corpus found. Please create a corpus first.")
    return self._corpus.documents

pretty_print()

Pretty print the corpus.

Source code in src/crisp_t/read_data.py
def pretty_print(self):
    """
    Pretty print the corpus.
    """
    if not self._corpus:
        self.create_corpus()
    if self._corpus:
        print(
            self._corpus.model_dump_json(indent=4, exclude={"df", "visualization"})
        )
        logger.info(
            "Corpus: %s",
            self._corpus.model_dump_json(indent=4, exclude={"df", "visualization"}),
        )
    else:
        logger.error("No corpus available to pretty print.")

read_corpus_from_json(file_path='', comma_separated_ignore_words='')

Read the corpus from a json file. Parallelizes ignore word removal for large corpora.

Source code in src/crisp_t/read_data.py
def read_corpus_from_json(self, file_path="", comma_separated_ignore_words=""):
    """
    Read the corpus from a json file. Parallelizes ignore word removal for large corpora.
    """
    from pathlib import Path

    file_path = Path(file_path)
    file_name = file_path / "corpus.json"
    df_name = file_path / "corpus_df.csv"
    if self._source:
        file_name = Path(self._source) / file_name
    if not file_name.exists():
        raise ValueError(f"File not found: {file_name}")
    with open(file_name, "r") as f:
        data = json.load(f)
        self._corpus = Corpus.model_validate(data)
        logger.info(f"Corpus read from {file_name}")
    if df_name.exists():
        self._corpus.df = pd.read_csv(df_name)
    else:
        self._corpus.df = None
    # Remove ignore words from self._corpus.documents text
    documents = self._corpus.documents

    # Pre-compile regex patterns once for efficiency instead of inside loops
    compiled_patterns = []
    if comma_separated_ignore_words:
        for word in comma_separated_ignore_words.split(","):
            pattern = re.compile(r"\b" + word.strip() + r"\b", flags=re.IGNORECASE)
            compiled_patterns.append(pattern)

    if len(documents) < 10:
        processed_docs = []
        for document in tqdm(documents, desc="Processing documents", disable=True):
            for pattern in compiled_patterns:
                document.text = pattern.sub("", document.text)
            processed_docs.append(document)
    else:

        def process_doc(document):
            for pattern in compiled_patterns:
                document.text = pattern.sub("", document.text)
            return document

        processed_docs = []
        n_cores = multiprocessing.cpu_count()
        with ThreadPoolExecutor() as executor:
            futures = {
                executor.submit(process_doc, document): document
                for document in documents
            }
            with tqdm(
                total=len(futures),
                desc=f"Processing documents (parallel, {n_cores} cores)",
            ) as pbar:
                for future in as_completed(futures):
                    processed_docs.append(future.result())
                    pbar.update(1)
    self._corpus.documents = processed_docs
    return self._corpus
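
A minimal sketch (the folder and ignore words are hypothetical): corpus.json is loaded from the given folder, corpus_df.csv is attached if present, and the listed words are stripped from every document's text case-insensitively.

from crisp_t.read_data import ReadData

rd = ReadData()
corpus = rd.read_corpus_from_json("out", comma_separated_ignore_words="um,uh")
print(len(corpus.documents))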

read_csv_to_corpus(file_name, comma_separated_ignore_words=None, comma_separated_text_columns='', id_column='')

Read the corpus from a csv file. Parallelizes document creation for large CSVs.

Source code in src/crisp_t/read_data.py
def read_csv_to_corpus(
    self,
    file_name,
    comma_separated_ignore_words=None,
    comma_separated_text_columns="",
    id_column="",
):
    """
    Read the corpus from a csv file. Parallelizes document creation for large CSVs.
    """
    from pathlib import Path

    file_name = Path(file_name)
    if not file_name.exists():
        raise ValueError(f"File not found: {file_name}")
    df = pd.read_csv(file_name)
    original_df = df.copy()
    if comma_separated_text_columns:
        text_columns = comma_separated_text_columns.split(",")
    else:
        text_columns = []
    # remove text columns from the dataframe
    for column in text_columns:
        if column in df.columns:
            df.drop(column, axis=1, inplace=True)
    # Set self._df to the numeric part after dropping text columns
    self._df = df.copy()
    rows = list(original_df.iterrows())

    # Pre-compile regex patterns once for efficiency instead of inside loops
    compiled_patterns = []
    if comma_separated_ignore_words:
        for word in comma_separated_ignore_words.split(","):
            pattern = re.compile(r"\b" + word.strip() + r"\b", flags=re.IGNORECASE)
            compiled_patterns.append(pattern)

    def create_document(args):
        index, row = args
        # Join the text columns, treating None/NaN cells as empty strings
        text_parts = [
            "" if pd.isna(row[column]) else str(row[column])
            for column in text_columns
        ]
        read_from_file = " ".join(text_parts)
        # Apply pre-compiled patterns
        for pattern in compiled_patterns:
            read_from_file = pattern.sub("", read_from_file)
        _document = Document(
            text=read_from_file,
            metadata={
                "source": str(file_name),
                "file_name": str(file_name),
                "row": index,
                "id": (
                    row[id_column]
                    if (id_column != "" and id_column in original_df.columns)
                    else index
                ),
            },
            id=str(index),
            score=0.0,
            name="",
            description="",
        )
        return read_from_file, _document

    if len(rows) < 10:
        results = [
            create_document(args)
            for args in tqdm(rows, desc="Reading CSV rows", disable=True)
        ]
    else:
        results = []
        n_cores = multiprocessing.cpu_count()
        with ThreadPoolExecutor() as executor:
            futures = {
                executor.submit(create_document, args): args for args in rows
            }
            with tqdm(
                total=len(futures),
                desc=f"Reading CSV rows (parallel, {n_cores} cores)",
            ) as pbar:
                for future in as_completed(futures):
                    results.append(future.result())
                    pbar.update(1)

    if len(results) < 10:
        for read_from_file, _document in tqdm(
            results, desc="Finalizing corpus", disable=True
        ):
            self._documents.append(_document)
    else:
        # This loop only appends to the document list; it runs sequentially,
        # so the progress bar does not advertise parallelism.
        with tqdm(
            results,
            total=len(results),
            desc="Finalizing corpus",
        ) as pbar:
            for read_from_file, _document in pbar:
                self._documents.append(_document)
    logger.info(f"Corpus read from {file_name}")
    self.create_corpus()
    return self._corpus
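
A usage sketch with hypothetical file and column names: the listed text columns are concatenated into each Document's text, while the remaining columns stay in the corpus dataframe.

from crisp_t.read_data import ReadData

rd = ReadData()
corpus = rd.read_csv_to_corpus(
    "responses.csv",
    comma_separated_text_columns="feedback,comments",
    comma_separated_ignore_words="n/a",
    id_column="respondent_id",
)
print(corpus.df.columns.tolist())   # text columns have been dropped from the dataframe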

write_corpus_to_json(file_path='', corpus=None)

Write the corpus to a json file.

Accepts either a directory path or an explicit file path ending with 'corpus.json'. In both cases, a sibling 'corpus_df.csv' will be written next to the json if a DataFrame is available.

Source code in src/crisp_t/read_data.py
def write_corpus_to_json(self, file_path="", corpus=None):
    """
    Write the corpus to a json file.

    Accepts either a directory path or an explicit file path ending with
    'corpus.json'. In both cases, a sibling 'corpus_df.csv' will be written
    next to the json if a DataFrame is available.
    """
    from pathlib import Path

    path = Path(file_path)
    # Determine targets
    if path.suffix:  # treat as explicit file path
        file_name = path
        df_name = path.with_name("corpus_df.csv")
    else:
        file_name = path / "corpus.json"
        df_name = path / "corpus_df.csv"

    corp = corpus if corpus is not None else self._corpus
    if not corp:
        raise ValueError("No corpus found. Please create a corpus first.")
    file_name.parent.mkdir(parents=True, exist_ok=True)
    with open(file_name, "w") as f:
        json.dump(corp.model_dump(exclude={"df", "visualization"}), f, indent=4)
    if corp.df is not None and isinstance(corp.df, pd.DataFrame):
        if not corp.df.empty:
            corp.df.to_csv(df_name, index=False)
    logger.info("Corpus written to %s", file_name)
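
A round-trip sketch (paths are hypothetical) showing both accepted forms of the target path:

from crisp_t.read_data import ReadData

rd = ReadData()
rd.read_csv_to_corpus("responses.csv", comma_separated_text_columns="feedback")  # hypothetical CSV
rd.write_corpus_to_json("out")                  # writes out/corpus.json (+ out/corpus_df.csv)
rd.write_corpus_to_json("backup/corpus.json")   # explicit file path; CSV written alongside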

main(verbose, id, name, description, docs, remove_docs, metas, relationships, clear_rel, print_corpus, out, inp, df_cols, df_row_count, df_row, doc_ids, doc_id, print_relationships, relationships_for_keyword, semantic, similar_docs, num, semantic_chunks, rec, metadata_df, metadata_keys, tdabm, graph)

CRISP-T Corpus CLI: create and manipulate a corpus quickly from the command line.

Source code in src/crisp_t/corpuscli.py
@click.command()
@click.option("--verbose", "-v", is_flag=True, help="Print verbose messages.")
@click.option("--id", help="Unique identifier for the corpus.")
@click.option("--name", default=None, help="Name of the corpus.")
@click.option("--description", default=None, help="Description of the corpus.")
@click.option(
    "--doc",
    "docs",
    multiple=True,
    help=(
        "Add a document as 'id|name|text' (or 'id|text' if name omitted). "
        "Can be used multiple times."
    ),
)
@click.option(
    "--remove-doc",
    "remove_docs",
    multiple=True,
    help="Remove a document by its ID (can be used multiple times).",
)
@click.option(
    "--meta",
    "metas",
    multiple=True,
    help="Add or update corpus metadata as key=value (can be used multiple times).",
)
@click.option(
    "--add-rel",
    "relationships",
    multiple=True,
    help=(
        "Add a relationship as 'first|second|relation' (e.g., text:term|numb:col|correlates)."
    ),
)
@click.option(
    "--clear-rel",
    is_flag=True,
    help="Clear all relationships in the corpus metadata.",
)
@click.option("--print", "print_corpus", is_flag=True, help="Pretty print the corpus")
@click.option(
    "--out", default=None, help="Write corpus to a folder or file as corpus.json (save)"
)
@click.option(
    "--inp",
    default=None,
    help="Load corpus from a folder or file containing corpus.json (load)",
)
# New options for Corpus methods
@click.option("--df-cols", is_flag=True, help="Print all DataFrame column names.")
@click.option("--df-row-count", is_flag=True, help="Print number of rows in DataFrame.")
@click.option("--df-row", default=None, type=int, help="Print DataFrame row by index.")
@click.option("--doc-ids", is_flag=True, help="Print all document IDs in the corpus.")
@click.option("--doc-id", default=None, help="Print document by ID.")
@click.option(
    "--relationships",
    "print_relationships",
    is_flag=True,
    help="Print all relationships in the corpus.",
)
@click.option(
    "--relationships-for-keyword",
    default=None,
    help="Print all relationships involving a specific keyword.",
)
@click.option(
    "--semantic",
    default=None,
    help="Perform semantic search with the given query string. Returns similar documents.",
)
@click.option(
    "--similar-docs",
    default=None,
    help="Find documents similar to a comma-separated list of document IDs. Use with --num and --rec. Useful for literature reviews.",
)
@click.option(
    "--num",
    default=5,
    type=int,
    help="Number of results to return (default: 5). Used for semantic search and similar documents search.",
)
@click.option(
    "--semantic-chunks",
    default=None,
    help="Perform semantic search on document chunks. Returns matching chunks for a specific document. Use with --doc-id and --rec (threshold).",
)
@click.option(
    "--rec",
    default=0.4,
    type=float,
    help="Threshold for semantic search (0-1, default: 0.4). Only chunks with similarity above this value are returned.",
)
@click.option(
    "--metadata-df",
    is_flag=True,
    help="Export collection metadata as DataFrame. Requires semantic search to be initialized first.",
)
@click.option(
    "--metadata-keys",
    default=None,
    help="Comma-separated list of metadata keys to include in DataFrame export.",
)
@click.option(
    "--tdabm",
    default=None,
    help="Perform TDABM analysis. Format: 'y_variable:x_variables:radius' (e.g., 'satisfaction:age,income:0.3'). Radius defaults to 0.3 if omitted.",
)
@click.option(
    "--graph",
    is_flag=True,
    help="Generate graph representation of the corpus. Requires documents to have keywords assigned (run with --assign first).",
)
def main(
    verbose: bool,
    id: Optional[str],
    name: Optional[str],
    description: Optional[str],
    docs: tuple[str, ...],
    remove_docs: tuple[str, ...],
    metas: tuple[str, ...],
    relationships: tuple[str, ...],
    clear_rel: bool,
    print_corpus: bool,
    out: Optional[str],
    inp: Optional[str],
    df_cols: bool,
    df_row_count: bool,
    df_row: Optional[int],
    doc_ids: bool,
    doc_id: Optional[str],
    print_relationships: bool,
    relationships_for_keyword: Optional[str],
    semantic: Optional[str],
    similar_docs: Optional[str],
    num: int,
    semantic_chunks: Optional[str],
    rec: float,
    metadata_df: bool,
    metadata_keys: Optional[str],
    tdabm: Optional[str],
    graph: bool,
):
    """
    CRISP-T Corpus CLI: create and manipulate a corpus quickly from the command line.
    """
    logging.basicConfig(level=(logging.DEBUG if verbose else logging.WARNING))
    logger = logging.getLogger(__name__)

    if verbose:
        click.echo("Verbose mode enabled")

    click.echo("_________________________________________")
    click.echo("CRISP-T: Corpus CLI")
    click.echo("_________________________________________")

    # Load corpus from --inp if provided
    corpus = initialize_corpus(inp=inp)
    if not corpus:
        # Build initial corpus from CLI args
        if not id:
            raise click.ClickException("--id is required when not using --inp.")
        corpus = Corpus(
            id=id,
            name=name,
            description=description,
            score=None,
            documents=[],
            df=None,
            visualization={},
            metadata={},
        )

    # Add documents (avoid reusing the name doc_id, which is also the --doc-id option)
    for d in docs:
        new_id, new_name, new_text = _parse_doc(d)
        document = Document(
            id=new_id,
            name=new_name,
            description=None,
            score=0.0,
            text=new_text,
            metadata={},
        )
        corpus.add_document(document)
    if docs:
        click.echo(f"✓ Added {len(docs)} document(s)")

    # Remove documents
    for rid in remove_docs:
        corpus.remove_document_by_id(rid)
    if remove_docs:
        click.echo(f"✓ Removed {len(remove_docs)} document(s)")

    # Update metadata
    for m in metas:
        k, v = _parse_kv(m)
        corpus.update_metadata(k, v)
    if metas:
        click.echo(f"✓ Updated metadata entries: {len(metas)}")

    # Relationships
    for r in relationships:
        first, second, relation = _parse_relationship(r)
        corpus.add_relationship(first, second, relation)
    if relationships:
        click.echo(f"✓ Added {len(relationships)} relationship(s)")
    if clear_rel:
        corpus.clear_relationships()
        click.echo("✓ Cleared relationships")

    # Print DataFrame column names
    if df_cols:
        cols = corpus.get_all_df_column_names()
        click.echo(f"DataFrame columns: {cols}")

    # Print DataFrame row count
    if df_row_count:
        count = corpus.get_row_count()
        click.echo(f"DataFrame row count: {count}")

    # Print DataFrame row by index
    if df_row is not None:
        row = corpus.get_row_by_index(df_row)
        if row is not None:
            click.echo(f"DataFrame row {df_row}: {row.to_dict()}")
        else:
            click.echo(f"No row at index {df_row}")

    # Print all document IDs
    if doc_ids:
        ids = corpus.get_all_document_ids()
        click.echo(f"Document IDs: {ids}")

    # Print document by ID
    if doc_id:
        doc = corpus.get_document_by_id(doc_id)
        if doc:
            click.echo(f"Document {doc_id}: {doc.model_dump()}")
        else:
            click.echo(f"No document found with ID {doc_id}")
            exit(0)

    # Print relationships
    if print_relationships:
        rels = corpus.get_relationships()
        click.echo(f"Relationships: {rels}")

    # Print relationships for keyword
    if relationships_for_keyword:
        rels = corpus.get_all_relationships_for_keyword(relationships_for_keyword)
        click.echo(f"Relationships for keyword '{relationships_for_keyword}': {rels}")

    # Semantic search
    if semantic:
        try:
            from .semantic import Semantic

            click.echo(f"\nPerforming semantic search for: '{semantic}'")
            # Try with default embeddings first, fall back to simple embeddings
            try:
                semantic_analyzer = Semantic(corpus)
            except Exception as network_error:
                # If network error or download fails, try simple embeddings
                if "address" in str(network_error).lower() or "download" in str(network_error).lower():
                    click.echo("Note: Using simple embeddings (network unavailable)")
                    semantic_analyzer = Semantic(corpus, use_simple_embeddings=True)
                else:
                    raise
            corpus = semantic_analyzer.get_similar(semantic, n_results=num)
            click.echo(f"✓ Found {len(corpus.documents)} similar documents")
            click.echo(
                f"Hint: Use --out to save the filtered corpus, or --print to view results"
            )
        except ImportError as e:
            click.echo(f"Error: {e}")
            click.echo("Install chromadb with: pip install chromadb")
        except Exception as e:
            click.echo(f"Error during semantic search: {e}")

    # Find similar documents
    if similar_docs:
        try:
            from .semantic import Semantic

            click.echo(f"\nFinding documents similar to: '{similar_docs}'")
            click.echo(f"Number of results: {num}")
            # Convert rec to 0-1 range if needed (for similar_docs, threshold is 0-1)
            threshold = rec / 10.0 if rec > 1.0 else rec
            click.echo(f"Similarity threshold: {threshold}")

            # Try with default embeddings first, fall back to simple embeddings
            try:
                semantic_analyzer = Semantic(corpus)
            except Exception as network_error:
                # If network error or download fails, try simple embeddings
                if "address" in str(network_error).lower() or "download" in str(network_error).lower():
                    click.echo("Note: Using simple embeddings (network unavailable)")
                    semantic_analyzer = Semantic(corpus, use_simple_embeddings=True)
                else:
                    raise

            # Get similar document IDs
            similar_doc_ids = semantic_analyzer.get_similar_documents(
                document_ids=similar_docs,
                n_results=num,
                threshold=threshold
            )

            click.echo(f"✓ Found {len(similar_doc_ids)} similar documents")
            if similar_doc_ids:
                click.echo("\nSimilar Document IDs:")
                for doc_id in similar_doc_ids:
                    doc = corpus.get_document_by_id(doc_id)
                    doc_name = f" ({doc.name})" if doc and doc.name else ""
                    click.echo(f"  - {doc_id}{doc_name}")
                click.echo("\nHint: Use --doc-id to view individual documents")
                click.echo("Hint: This feature is useful for literature reviews to find similar documents")
            else:
                click.echo("No similar documents found above the threshold.")
                click.echo("Hint: Try lowering the threshold with --rec")

        except ImportError as e:
            click.echo(f"Error: {e}")
            click.echo("Install chromadb with: pip install chromadb")
        except Exception as e:
            click.echo(f"Error finding similar documents: {e}")


    # Semantic chunk search
    if semantic_chunks:
        if not doc_id:
            click.echo("Error: --doc-id is required when using --semantic-chunks")
        else:
            try:
                from .semantic import Semantic

                click.echo(f"\nPerforming semantic chunk search for: '{semantic_chunks}'")
                click.echo(f"Document ID: {doc_id}")
                click.echo(f"Threshold: {rec}")

                # Try with default embeddings first, fall back to simple embeddings
                try:
                    semantic_analyzer = Semantic(corpus)
                except Exception as network_error:
                    # If network error or download fails, try simple embeddings
                    if "address" in str(network_error).lower() or "download" in str(network_error).lower():
                        click.echo("Note: Using simple embeddings (network unavailable)")
                        semantic_analyzer = Semantic(corpus, use_simple_embeddings=True)
                    else:
                        raise

                # Get similar chunks
                chunks = semantic_analyzer.get_similar_chunks(
                    query=semantic_chunks,
                    doc_id=doc_id,
                    threshold=rec,
                    n_results=20  # Get more chunks to filter by threshold
                )

                click.echo(f"✓ Found {len(chunks)} matching chunks")
                click.echo("\nMatching chunks:")
                click.echo("=" * 60)
                for i, chunk in enumerate(chunks, 1):
                    click.echo(f"\nChunk {i}:")
                    click.echo(chunk)
                    click.echo("-" * 60)

                if len(chunks) == 0:
                    click.echo("No chunks matched the query above the threshold.")
                    click.echo("Hint: Try lowering the threshold with --rec or use a different query.")
                else:
                    click.echo(f"\nHint: These {len(chunks)} chunks can be used for coding/annotating the document.")
                    click.echo("Hint: Adjust --rec threshold to get more or fewer results.")

            except ImportError as e:
                click.echo(f"Error: {e}")
                click.echo("Install chromadb with: pip install chromadb")
            except Exception as e:
                click.echo(f"Error during semantic chunk search: {e}")

    # Export metadata as DataFrame
    if metadata_df:
        try:
            from .semantic import Semantic

            click.echo("\nExporting metadata as DataFrame...")
            # Try with default embeddings first, fall back to simple embeddings
            try:
                semantic_analyzer = Semantic(corpus)
            except Exception as network_error:
                # If network error or download fails, try simple embeddings
                if "address" in str(network_error).lower() or "download" in str(network_error).lower():
                    click.echo("Note: Using simple embeddings (network unavailable)")
                    semantic_analyzer = Semantic(corpus, use_simple_embeddings=True)
                else:
                    raise
            # Parse metadata_keys if provided
            keys_list = None
            if metadata_keys:
                keys_list = [k.strip() for k in metadata_keys.split(",")]
            corpus = semantic_analyzer.get_df(metadata_keys=keys_list)
            click.echo("✓ Metadata exported to DataFrame")
            if corpus.df is not None:
                click.echo(f"DataFrame shape: {corpus.df.shape}")
                click.echo(f"Columns: {list(corpus.df.columns)}")
            click.echo("Hint: Use --out to save the corpus with the updated DataFrame")
        except ImportError as e:
            click.echo(f"Error: {e}")
            click.echo("Install chromadb with: pip install chromadb")
        except Exception as e:
            click.echo(f"Error exporting metadata: {e}")

    # TDABM analysis
    if tdabm:
        try:
            # Parse tdabm parameter: y_variable:x_variables:radius
            parts = tdabm.split(":")
            if len(parts) < 2:
                raise click.ClickException(
                    "Invalid --tdabm format. Use 'y_variable:x_variables:radius' "
                    "(e.g., 'satisfaction:age,income:0.3'). Radius defaults to 0.3 if omitted."
                )

            y_var = parts[0].strip()
            x_vars = parts[1].strip()
            radius = 0.3  # default

            if len(parts) >= 3:
                try:
                    radius = float(parts[2].strip())
                except ValueError:
                    raise click.ClickException(f"Invalid radius value: '{parts[2]}'. Must be a number.")

            click.echo(f"\nPerforming TDABM analysis...")
            click.echo(f"  Y variable: {y_var}")
            click.echo(f"  X variables: {x_vars}")
            click.echo(f"  Radius: {radius}")

            tdabm_analyzer = Tdabm(corpus)
            result = tdabm_analyzer.generate_tdabm(y=y_var, x_variables=x_vars, radius=radius)

            click.echo("\n" + result)
            click.echo("\nHint: TDABM results stored in corpus metadata['tdabm']")
            click.echo("Hint: Use --out to save the corpus with TDABM metadata")
            click.echo("Hint: Use 'crispviz --tdabm' to visualize the results")

        except ValueError as e:
            click.echo(f"Error: {e}")
            click.echo("Hint: Ensure your corpus has a DataFrame with the specified variables")
            click.echo("Hint: Y variable must be continuous (not binary)")
            click.echo("Hint: X variables must be numeric/ordinal")
        except Exception as e:
            click.echo(f"Error during TDABM analysis: {e}")

    # Graph generation
    if graph:
        try:
            from .graph import CrispGraph

            click.echo("\nGenerating graph representation...")
            graph_gen = CrispGraph(corpus)
            graph_data = graph_gen.create_graph()

            click.echo(f"✓ Graph created successfully")
            click.echo(f"  Nodes: {graph_data['num_nodes']}")
            click.echo(f"  Edges: {graph_data['num_edges']}")
            click.echo(f"  Documents: {graph_data['num_documents']}")
            click.echo(f"  Has keywords: {graph_data['has_keywords']}")
            click.echo(f"  Has clusters: {graph_data['has_clusters']}")
            click.echo(f"  Has metadata: {graph_data['has_metadata']}")

            click.echo("\nHint: Graph data stored in corpus metadata['graph']")
            click.echo("Hint: Use --out to save the corpus with graph metadata")
            click.echo("Hint: Use 'crispviz --graph' to visualize the graph")

        except ValueError as e:
            click.echo(f"Error: {e}")
            click.echo("Hint: Make sure documents have keywords assigned first")
            click.echo("Hint: You can assign keywords using text analysis features")
        except Exception as e:
            click.echo(f"Error generating graph: {e}")
            logger.error(f"Graph generation error: {e}", exc_info=True)

    # Save corpus to --out if provided
    if out:
        from .read_data import ReadData

        rd = ReadData(corpus=corpus)
        rd.write_corpus_to_json(out, corpus=corpus)
        click.echo(f"✓ Corpus saved to {out}")

    if print_corpus:
        click.echo("\n=== Corpus Details ===")
        corpus.pretty_print()

    logger.info("Corpus CLI finished")
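
Because main is a standard click command, it can be exercised programmatically with click's test runner; the corpus ID and document payload below are made up.

from click.testing import CliRunner
from crisp_t.corpuscli import main

runner = CliRunner()
result = runner.invoke(
    main,
    ["--id", "demo", "--doc", "d1|First doc|Hello world", "--print"],
)
print(result.output)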

main(verbose, inp, out, bins, topics_num, top_n, corr_columns, freq, by_topic, wordcloud, ldavis, top_terms, corr_heatmap, tdabm, graph, graph_nodes, graph_layout)

CRISP-T: Visualization CLI

Load a corpus from a folder containing corpus.json (via --inp) and export the selected visualizations as PNG files into the output directory.

Source code in src/crisp_t/vizcli.py
@click.command()
@click.option("--verbose", "-v", is_flag=True, help="Print verbose messages.")
@click.option("--inp", "-i", help="Load corpus from a folder containing corpus.json")
@click.option(
    "--out",
    "-o",
    help="Output directory where PNG images will be written",
)
@click.option(
    "--bins", default=100, show_default=True, help="Number of bins for distributions"
)
@click.option(
    "--topics-num",
    default=8,
    show_default=True,
    help="Number of topics for LDA when required (default 8 as per Mettler et al. 2025)",
)
@click.option(
    "--top-n",
    default=20,
    show_default=True,
    help="Top N terms to show in top-terms chart",
)
@click.option(
    "--corr-columns",
    default="",
    help="Comma separated numeric columns for correlation heatmap; if empty, auto-select",
)
@click.option("--freq", is_flag=True, help="Export: word frequency distribution")
@click.option(
    "--by-topic",
    is_flag=True,
    help="Export: distribution by dominant topic (requires LDA)",
)
@click.option(
    "--wordcloud", is_flag=True, help="Export: topic wordcloud (requires LDA)"
)
@click.option(
    "--ldavis",
    is_flag=True,
    help="Export: interactive LDA visualization HTML (requires LDA)",
)
@click.option(
    "--top-terms", is_flag=True, help="Export: top terms bar chart (computed from text)"
)
@click.option(
    "--corr-heatmap",
    is_flag=True,
    help="Export: correlation heatmap (from CSV numeric columns)",
)
@click.option(
    "--tdabm",
    is_flag=True,
    help="Export: TDABM visualization (requires TDABM analysis in corpus metadata)",
)
@click.option(
    "--graph",
    is_flag=True,
    help="Export: graph visualization (requires graph data in corpus metadata)",
)
@click.option(
    "--graph-nodes",
    default="",
    help=(
        "Comma separated node types to include for graph: document,keyword,cluster,metadata. "
        "Example: --graph-nodes document,keyword. If empty or 'all', include all."
    ),
)
@click.option(
    "--graph-layout",
    default="spring",
    show_default=True,
    help="Layout algorithm for graph visualization: spring, circular, kamada_kawai, or spectral",
)
def main(
    verbose: bool,
    inp: Optional[str],
    out: str,
    bins: int,
    topics_num: int,
    top_n: int,
    corr_columns: str,
    freq: bool,
    by_topic: bool,
    wordcloud: bool,
    ldavis: bool,
    top_terms: bool,
    corr_heatmap: bool,
    tdabm: bool,
    graph: bool,
    graph_nodes: str,
    graph_layout: str,
):
    """CRISP-T: Visualization CLI

    Build corpus (source preferred over inp), optionally handle multiple sources,
    and export selected visualizations as PNG files into the output directory.
    """

    if verbose:
        logging.getLogger().setLevel(logging.DEBUG)
        click.echo("Verbose mode enabled")

    click.echo("_________________________________________")
    click.echo("CRISP-T: Visualizations")
    click.echo(f"Version: {__version__}")
    click.echo("_________________________________________")

    try:
        out_dir = Path(out)
    except TypeError:
        click.echo(
            f"No output directory specified. Visualizations need an output folder."
        )
        raise click.Abort()
    out_dir.mkdir(parents=True, exist_ok=True)

    # Initialize the corpus from --inp
    corpus = initialize_corpus(inp=inp)

    if not corpus:
        raise click.ClickException("No input provided. Use --inp to load a corpus.")

    viz = QRVisualize(corpus=corpus)

    # Helper: build LDA if by-topic or wordcloud requested
    cluster_instance = None

    def ensure_topics():
        nonlocal cluster_instance
        if cluster_instance is None:
            cluster_instance = Cluster(corpus=corpus)
            cluster_instance.build_lda_model(topics=topics_num)
            # Populate visualization structures used by QRVisualize
            cluster_instance.format_topics_sentences(visualize=True)
        return cluster_instance

    # 1) Word frequency distribution
    if freq:
        df_text = pd.DataFrame(
            {"Text": [getattr(doc, "text", "") or "" for doc in corpus.documents]}
        )
        out_path = out_dir / "word_frequency.png"
        viz.plot_frequency_distribution_of_words(
            df=df_text, folder_path=str(out_path), bins=bins, show=False
        )
        click.echo(f"Saved: {out_path}")

    # 2) Distribution by topic (requires topics)
    if by_topic:
        ensure_topics()
        out_path = out_dir / "by_topic.png"
        viz.plot_distribution_by_topic(
            df=None, folder_path=str(out_path), bins=bins, show=False
        )
        click.echo(f"Saved: {out_path}")

    # 3) Topic wordcloud (requires topics)
    if wordcloud:
        ensure_topics()
        out_path = out_dir / "wordcloud.png"
        viz.plot_wordcloud(topics=None, folder_path=str(out_path), show=False)
        click.echo(f"Saved: {out_path}")

    # 3.5) LDA visualization (requires topics)
    if ldavis:
        cluster = ensure_topics()
        out_path = out_dir / "lda_visualization.html"
        try:
            viz.get_lda_viz(
                lda_model=cluster._lda_model,
                corpus_bow=cluster._bag_of_words,
                dictionary=cluster._dictionary,
                folder_path=str(out_path),
                show=False,
            )
            click.echo(f"Saved: {out_path}")
        except ImportError as e:
            click.echo(f"Warning: {e}")
        except Exception as e:
            click.echo(f"Error generating LDA visualization: {e}")

    # 4) Top terms (compute from text directly)
    if top_terms:
        texts = [getattr(doc, "text", "") or "" for doc in corpus.documents]
        tokens = []
        for t in texts:
            tokens.extend((t or "").lower().split())
        freq_map = Counter(tokens)
        if not freq_map:
            click.echo("No tokens found to plot top terms.")
        else:
            df_terms = pd.DataFrame(
                {
                    "term": list(freq_map.keys()),
                    "frequency": list(freq_map.values()),
                }
            )
            # QRVisualize sorts internally; we just pass full DF
            out_path = out_dir / "top_terms.png"
            viz.plot_top_terms(
                df=df_terms, top_n=top_n, folder_path=str(out_path), show=False
            )
            click.echo(f"Saved: {out_path}")

    # 5) Correlation heatmap
    if corr_heatmap:
        if getattr(corpus, "df", None) is None or corpus.df.empty:
            click.echo("No CSV data available for correlation heatmap; skipping.")
        else:
            df0 = corpus.df.copy()
            # If user specified columns, attempt to use them; else let visualize auto-select
            cols = (
                [c.strip() for c in corr_columns.split(",") if c.strip()]
                if corr_columns
                else None
            )
            out_path = out_dir / "corr_heatmap.png"
            if cols:
                # Pass subset to avoid rename ambiguity
                sub = (
                    df0[cols].copy().select_dtypes(include=["number"])
                )  # ensure numeric
                viz.plot_correlation_heatmap(
                    df=sub, columns=None, folder_path=str(out_path), show=False
                )
            else:
                viz.plot_correlation_heatmap(
                    df=df0, columns=None, folder_path=str(out_path), show=False
                )
            click.echo(f"Saved: {out_path}")

    # TDABM visualization
    if tdabm:
        if "tdabm" not in corpus.metadata:
            click.echo("Warning: No TDABM data found in corpus metadata.")
            click.echo(
                "Hint: Run TDABM analysis first with: crispt --tdabm y_var:x_vars:radius --inp <corpus_dir>"
            )
        else:
            out_path = out_dir / "tdabm.png"
            try:
                viz.draw_tdabm(corpus=corpus, folder_path=str(out_path), show=False)
                click.echo(f"Saved: {out_path}")
            except Exception as e:
                click.echo(f"Error generating TDABM visualization: {e}")
                logger.error(f"TDABM visualization error: {e}", exc_info=True)

    # Graph visualization (filtered by node types if provided)
    if graph or graph_nodes:
        if "graph" not in corpus.metadata:
            click.echo("Warning: No graph data found in corpus metadata.")
            click.echo(
                "Hint: Run graph generation first with: crispt --graph --inp <corpus_dir>"
            )
        else:
            raw_types = (graph_nodes or "").strip().lower()
            include_all = raw_types in ("", "all", "*")
            allowed_types = {"document", "keyword", "cluster", "metadata"}
            requested_types = set()
            if not include_all:
                for part in raw_types.split(","):
                    p = part.strip()
                    if not p:
                        continue
                    if p in allowed_types:
                        requested_types.add(p)
                    else:
                        click.echo(
                            f"Warning: Unknown node type '{p}' ignored. Allowed: {', '.join(sorted(allowed_types))}"
                        )
                if not requested_types:
                    click.echo("No valid node types specified; defaulting to all.")
                    include_all = True

            graph_data = corpus.metadata.get("graph", {})
            nodes = graph_data.get("nodes", [])
            edges = graph_data.get("edges", [])

            if include_all:
                filtered_nodes = nodes
                filtered_edges = edges
            else:
                filtered_nodes = [n for n in nodes if n.get("label") in requested_types]
                kept_ids = {str(n.get("id")) for n in filtered_nodes}
                filtered_edges = [
                    e
                    for e in edges
                    if str(e.get("source")) in kept_ids
                    and str(e.get("target")) in kept_ids
                ]

            # Build a shallow copy of graph metadata with filtered components
            filtered_graph_meta = dict(graph_data)
            filtered_graph_meta["nodes"] = filtered_nodes
            filtered_graph_meta["edges"] = filtered_edges
            filtered_graph_meta["num_nodes"] = len(filtered_nodes)
            filtered_graph_meta["num_edges"] = len(filtered_edges)
            filtered_graph_meta["num_documents"] = sum(
                1 for n in filtered_nodes if n.get("label") == "document"
            )

            # Inject temporary filtered metadata for visualization
            original_graph_meta = corpus.metadata.get("graph")
            corpus.metadata["graph"] = filtered_graph_meta
            out_path = out_dir / "graph.png"
            try:
                viz.draw_graph(
                    corpus=corpus,
                    folder_path=str(out_path),
                    show=False,
                    layout=graph_layout,
                )
                click.echo(f"Saved: {out_path}")
                if not include_all:
                    click.echo(
                        f"Graph filtered to node types: {', '.join(sorted(requested_types))}"
                    )
            except Exception as e:
                click.echo(f"Error generating graph visualization: {e}")
                logger.error(f"Graph visualization error: {e}", exc_info=True)
            finally:
                # Restore original metadata (avoid side-effects)
                corpus.metadata["graph"] = original_graph_meta

    click.echo("\n=== Visualization Complete ===")

Copyright (C) 2025 Bell Eapen

This file is part of crisp-t.

crisp-t is free software: you can redistribute it and/or modify it under the terms of the GNU General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your option) any later version.

crisp-t is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for more details.

You should have received a copy of the GNU General Public License along with crisp-t. If not, see https://www.gnu.org/licenses/.

Corpus

Bases: BaseModel

Corpus model for storing a collection of documents.
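
A minimal construction sketch (field values are illustrative, and the Document import path is assumed), mirroring how the Corpus CLI above builds an empty corpus and adds a document to it.

from crisp_t.model.corpus import Corpus
from crisp_t.model.document import Document  # assumed module path

corpus = Corpus(
    id="corpus_demo",
    name="Demo",
    description="Illustrative corpus",
    score=None,
    documents=[],
    df=None,
    visualization={},
    metadata={},
)
corpus.add_document(
    Document(id="d1", name="First", description=None, score=0.0, text="Hello", metadata={})
)
corpus.pretty_print(show="documents 1")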

Source code in src/crisp_t/model/corpus.py
class Corpus(BaseModel):
    """
    Corpus model for storing a collection of documents.
    """

    id: str = Field(..., description="Unique identifier for the corpus.")
    name: Optional[str] = Field(None, description="Name of the corpus.")
    description: Optional[str] = Field(None, description="Description of the corpus.")
    score: Optional[float] = Field(
        None, description="Score associated with the corpus."
    )
    documents: list[Document] = Field(
        default_factory=list, description="List of documents in the corpus."
    )
    df: Optional[pd.DataFrame] = Field(
        None, description="Numeric data associated with the corpus."
    )
    visualization: Dict[str, Any] = Field(
        default_factory=dict, description="Visualization data associated with the corpus."
    )
    model_config = ConfigDict(
        arbitrary_types_allowed=True
    )  # required for pandas DataFrame
    metadata: dict = Field(
        default_factory=dict, description="Metadata associated with the corpus."
    )

    def pretty_print(self, show="all"):
        """
        Print the corpus information in a human-readable format.

        Args:
            show: Display option. Can be:
                - "all": Show all corpus information
                - "documents": Show first 5 documents
                - "documents N": Show first N documents (e.g., "documents 10")
                - "documents metadata": Show document-specific metadata
                - "dataframe": Show DataFrame head
                - "dataframe metadata": Show DataFrame metadata columns (metadata_*)
                - "dataframe stats": Show DataFrame statistics
                - "metadata": Show all corpus metadata
                - "metadata KEY": Show specific metadata key (e.g., "metadata pca")
                - "stats": Show DataFrame statistics (deprecated, use "dataframe stats")
        """
        # Color codes for terminal output
        BLUE = "\033[94m"
        GREEN = "\033[92m"
        YELLOW = "\033[93m"
        CYAN = "\033[96m"
        MAGENTA = "\033[95m"
        RED = "\033[91m"
        RESET = "\033[0m"
        BOLD = "\033[1m"

        # Parse the show parameter to support subcommands
        parts = show.split(maxsplit=1)
        main_command = parts[0]
        sub_command = parts[1] if len(parts) > 1 else None

        # Print basic corpus info for most commands
        if main_command in ["all", "documents", "dataframe", "metadata"]:
            print(f"{BOLD}{BLUE}Corpus ID:{RESET} {self.id}")
            print(f"{BOLD}{BLUE}Name:{RESET} {self.name}")
            print(f"{BOLD}{BLUE}Description:{RESET} {self.description}")

        # Handle documents command
        if main_command in ["all", "documents"]:
            if sub_command == "metadata":
                # Show document-specific metadata
                print(f"\n{BOLD}{GREEN}=== Document Metadata ==={RESET}")
                if not self.documents:
                    print("No documents in corpus")
                else:
                    for i, doc in enumerate(self.documents, 1):
                        print(f"\n{CYAN}Document {i}:{RESET}")
                        print(f"  {BOLD}ID:{RESET} {doc.id}")
                        print(f"  {BOLD}Name:{RESET} {doc.name}")
                        if doc.metadata:
                            print(f"  {BOLD}Metadata:{RESET}")
                            for key, value in doc.metadata.items():
                                # Truncate long values
                                val_str = str(value)
                                if len(val_str) > 100:
                                    val_str = val_str[:97] + "..."
                                print(f"    {YELLOW}{key}:{RESET} {val_str}")
                        else:
                            print(f"  {BOLD}Metadata:{RESET} (none)")
            else:
                # Determine how many documents to show
                num_docs = 5  # default
                if sub_command:
                    try:
                        num_docs = int(sub_command)
                    except ValueError:
                        print(f"{RED}Invalid number for documents: {sub_command}. Using default (5).{RESET}")

                print(f"\n{BOLD}{GREEN}=== Documents ==={RESET}")
                print(f"Total documents: {len(self.documents)}")
                print(f"Showing first {min(num_docs, len(self.documents))} document(s):\n")

                for i, doc in enumerate(self.documents[:num_docs], 1):
                    print(f"{CYAN}Document {i}:{RESET}")
                    print(f"  {BOLD}Name:{RESET} {doc.name}")
                    print(f"  {BOLD}ID:{RESET} {doc.id}")
                    # Show a snippet of text if available
                    if hasattr(doc, 'text') and doc.text:
                        text_snippet = doc.text[:200] + "..." if len(doc.text) > 200 else doc.text
                        print(f"  {BOLD}Text:{RESET} {text_snippet}")
                    print()

        # Handle dataframe command
        if main_command in ["all", "dataframe"]:
            if self.df is not None:
                if sub_command == "metadata":
                    # Show DataFrame metadata columns (columns starting with metadata_)
                    print(f"\n{BOLD}{GREEN}=== DataFrame Metadata Columns ==={RESET}")
                    metadata_cols = [col for col in self.df.columns if col.startswith("metadata_")]
                    if metadata_cols:
                        print(f"Found {len(metadata_cols)} metadata column(s):")
                        for col in metadata_cols:
                            print(f"  {YELLOW}{col}{RESET}")
                            # Show some statistics for the metadata column
                            print(f"    Non-null values: {self.df[col].notna().sum()}")
                            print(f"    Null values: {self.df[col].isna().sum()}")
                            # Show unique values if not too many
                            unique_count = self.df[col].nunique()
                            if unique_count <= 10:
                                print(f"    Unique values ({unique_count}): {list(self.df[col].unique())}")
                            else:
                                print(f"    Unique values: {unique_count}")
                    else:
                        print("No metadata columns found (columns starting with 'metadata_')")
                elif sub_command == "stats":
                    # Show DataFrame statistics
                    self._print_dataframe_stats()
                else:
                    # Show DataFrame head
                    print(f"\n{BOLD}{GREEN}=== DataFrame ==={RESET}")
                    print(f"Shape: {self.df.shape}")
                    print(f"Columns: {list(self.df.columns)}")
                    print("\nFirst few rows:")
                    print(self.df.head())
            else:
                if main_command == "dataframe":
                    print(f"\n{BOLD}{RED}No DataFrame available{RESET}")

        # Handle metadata command
        if main_command in ["all", "metadata"]:
            if sub_command:
                # Show specific metadata key
                print(f"\n{BOLD}{GREEN}=== Metadata: {sub_command} ==={RESET}")
                if sub_command in self.metadata:
                    value = self.metadata[sub_command]
                    # Format the output based on the type of value
                    if isinstance(value, dict):
                        for k, v in value.items():
                            print(f"{YELLOW}{k}:{RESET} {v}")
                    elif isinstance(value, list):
                        for i, item in enumerate(value, 1):
                            print(f"{i}. {item}")
                    else:
                        print(value)
                else:
                    print(f"{RED}Metadata key '{sub_command}' not found{RESET}")
                    available_keys = list(self.metadata.keys())
                    if available_keys:
                        print(f"Available keys: {', '.join(available_keys)}")
            else:
                # Show all metadata
                print(f"\n{BOLD}{GREEN}=== Corpus Metadata ==={RESET}")
                if not self.metadata:
                    print("No metadata available")
                else:
                    for key, value in self.metadata.items():
                        print(f"\n{MAGENTA}{key}:{RESET}")
                        # Truncate long values
                        val_str = str(value)
                        if len(val_str) > 500:
                            val_str = val_str[:497] + "..."
                        print(f"  {val_str}")

        # Handle stats command (deprecated, redirect to dataframe stats)
        if main_command == "stats":
            print(f"{YELLOW}Note: 'stats' is deprecated. Use 'dataframe stats' instead.{RESET}")
            if self.df is not None:
                self._print_dataframe_stats()
            else:
                print(f"{RED}No DataFrame available{RESET}")

        print(f"\n{BOLD}Display completed for '{show}'{RESET}")

    def _print_dataframe_stats(self):
        """Helper method to print DataFrame statistics."""
        YELLOW = "\033[93m"
        BOLD = "\033[1m"
        RESET = "\033[0m"
        GREEN = "\033[92m"

        print(f"\n{BOLD}{GREEN}=== DataFrame Statistics ==={RESET}")
        print(self.df.describe())
        print(f"\n{BOLD}Distinct values per column:{RESET}")
        for col in self.df.columns:
            nunique = self.df[col].nunique()
            print(f"  {YELLOW}{col}:{RESET} {nunique} distinct value(s)")
            # If distinct values < 10, show value counts
            if nunique <= 10:
                print(f"    Value counts:")
                for val, count in self.df[col].value_counts().items():
                    print(f"      {val}: {count}")
                print()

    def get_all_df_column_names(self):
        """
        Get a list of all column names in the DataFrame.

        Returns:
            List of column names.
        """
        if self.df is not None:
            return self.df.columns.tolist()
        return []

    def get_descriptive_statistics(self):
        """
        Get descriptive statistics of the DataFrame.

        Returns:
            DataFrame containing descriptive statistics, or None if DataFrame is None.
        """
        if self.df is not None:
            return self.df.describe()
        return None

    def get_row_count(self):
        """
        Get the number of rows in the DataFrame.

        Returns:
            Number of rows in the DataFrame, or 0 if DataFrame is None.
        """
        if self.df is not None:
            return len(self.df)
        return 0

    def get_row_by_index(self, index: int) -> Optional[pd.Series]:
        """
        Get a row from the DataFrame by its index.

        Args:
            index: Index of the row to retrieve.
        Returns:
            Row as a pandas Series if index is valid, else None.
        """
        if self.df is not None and 0 <= index < len(self.df):
            return self.df.iloc[index]
        return None

    def get_all_document_ids(self):
        """
        Get a list of all document IDs in the corpus.

        Returns:
            List of document IDs.
        """
        return [doc.id for doc in self.documents]

    def get_document_by_id(self, document_id: str) -> Optional[Document]:
        """
        Get a document by its ID.

        Args:
            document_id: ID of the document to retrieve.

        Returns:
            Document object if found, else None.
        """
        for doc in self.documents:
            if doc.id == document_id:
                return doc
        return None

    def add_document(self, document: Document):
        """
        Add a document to the corpus.

        Args:
            document: Document object to add.
        """
        self.documents.append(document)

    def remove_document_by_id(self, document_id: str):
        """
        Remove a document from the corpus by its ID.

        Args:
            document_id: ID of the document to remove.
        """
        self.documents = [
            doc for doc in self.documents if doc.id != document_id
        ]

    def update_metadata(self, key: str, value: Any):
        """
        Update the metadata of the corpus.

        Args:
            key: Metadata key to update.
            value: New value for the metadata key.
        """
        self.metadata[key] = value

    def add_relationship(self, first: str, second: str, relation: str):
        """
        Add a relationship between two documents in the corpus.

        Args:
            first: keywords from text documents in the format text:keyword or columns from dataframe in the format numb:column
            second: keywords from text documents in the format text:keyword or columns from dataframe in the format numb:column
            relation: Description of the relationship. (One of "correlates", "similar to", "cites", "references", "contradicts", etc.)
        """
        if "relationships" not in self.metadata:
            self.metadata["relationships"] = []
        self.metadata["relationships"].append(
            {"first": first, "second": second, "relation": relation}
        )

    def clear_relationships(self):
        """
        Clear all relationships in the corpus metadata.
        """
        if "relationships" in self.metadata:
            self.metadata["relationships"] = []

    def get_relationships(self):
        """
        Get all relationships in the corpus metadata.

        Returns:
            List of relationships, or empty list if none exist.
        """
        return self.metadata.get("relationships", [])

    def get_all_relationships_for_keyword(self, keyword: str):
        """
        Get all relationships involving a specific keyword.

        Args:
            keyword: Keyword to search for in relationships.

        Returns:
            List of relationships involving the keyword.
        """
        rels = self.get_relationships()
        return [
            rel
            for rel in rels
            if keyword in rel["first"] or keyword in rel["second"]
        ]

add_document(document)

Add a document to the corpus.

Parameters:

document (Document): Document object to add. Required.
Source code in src/crisp_t/model/corpus.py, lines 306-313
def add_document(self, document: Document):
    """
    Add a document to the corpus.

    Args:
        document: Document object to add.
    """
    self.documents.append(document)

add_relationship(first, second, relation)

Add a relationship between two documents in the corpus.

Parameters:

first (str): Keywords from text documents in the format text:keyword, or columns from the DataFrame in the format numb:column. Required.
second (str): Keywords from text documents in the format text:keyword, or columns from the DataFrame in the format numb:column. Required.
relation (str): Description of the relationship (one of "correlates", "similar to", "cites", "references", "contradicts", etc.). Required.
Source code in src/crisp_t/model/corpus.py, lines 336-349
def add_relationship(self, first: str, second: str, relation: str):
    """
    Add a relationship between two documents in the corpus.

    Args:
        first: keywords from text documents in the format text:keyword or columns from dataframe in the format numb:column
        second: keywords from text documents in the format text:keyword or columns from dataframe in the format numb:column
        relation: Description of the relationship. (One of "correlates", "similar to", "cites", "references", "contradicts", etc.)
    """
    if "relationships" not in self.metadata:
        self.metadata["relationships"] = []
    self.metadata["relationships"].append(
        {"first": first, "second": second, "relation": relation}
    )
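
Example (not from the source): a minimal sketch of how the relationship helpers can be combined. The keyword, column names, and the import path (inferred from the source location src/crisp_t/model/corpus.py) are assumptions for illustration only.

from crisp_t.model.corpus import Corpus  # import path inferred from the source location above


def link_theme_to_column(corpus: Corpus) -> None:
    # Record links between a text keyword and another keyword or a numeric column
    corpus.add_relationship("text:fatigue", "numb:age", "correlates")
    corpus.add_relationship("text:fatigue", "text:burnout", "similar to")

    # Inspect everything stored under metadata["relationships"]
    for rel in corpus.get_relationships():
        print(rel["first"], rel["relation"], rel["second"])

    # Filter by keyword, then reset if needed
    print(corpus.get_all_relationships_for_keyword("fatigue"))
    corpus.clear_relationships()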

clear_relationships()

Clear all relationships in the corpus metadata.

Source code in src/crisp_t/model/corpus.py, lines 351-356
def clear_relationships(self):
    """
    Clear all relationships in the corpus metadata.
    """
    if "relationships" in self.metadata:
        self.metadata["relationships"] = []

get_all_df_column_names()

Get a list of all column names in the DataFrame.

Returns: List of column names.

Source code in src/crisp_t/model/corpus.py, lines 236-245
def get_all_df_column_names(self):
    """
    Get a list of all column names in the DataFrame.

    Returns:
        List of column names.
    """
    if self.df is not None:
        return self.df.columns.tolist()
    return []

get_all_document_ids()

Get a list of all document IDs in the corpus.

Returns: List of document IDs.

Source code in src/crisp_t/model/corpus.py, lines 282-289
def get_all_document_ids(self):
    """
    Get a list of all document IDs in the corpus.

    Returns:
        List of document IDs.
    """
    return [doc.id for doc in self.documents]

get_all_relationships_for_keyword(keyword)

Get all relationships involving a specific keyword.

Parameters:

keyword (str): Keyword to search for in relationships. Required.

Returns: List of relationships involving the keyword.

Source code in src/crisp_t/model/corpus.py, lines 367-382
def get_all_relationships_for_keyword(self, keyword: str):
    """
    Get all relationships involving a specific keyword.

    Args:
        keyword: Keyword to search for in relationships.

    Returns:
        List of relationships involving the keyword.
    """
    rels = self.get_relationships()
    return [
        rel
        for rel in rels
        if keyword in rel["first"] or keyword in rel["second"]
    ]

get_descriptive_statistics()

Get descriptive statistics of the DataFrame.

Returns: DataFrame containing descriptive statistics, or None if the DataFrame is None.

Source code in src/crisp_t/model/corpus.py, lines 247-256
def get_descriptive_statistics(self):
    """
    Get descriptive statistics of the DataFrame.

    Returns:
        DataFrame containing descriptive statistics, or None if DataFrame is None.
    """
    if self.df is not None:
        return self.df.describe()
    return None

get_document_by_id(document_id)

Get a document by its ID.

Parameters:

document_id (str): ID of the document to retrieve. Required.

Returns: Optional[Document] - Document object if found, else None.

Source code in src/crisp_t/model/corpus.py, lines 291-304
def get_document_by_id(self, document_id: str) -> Optional[Document]:
    """
    Get a document by its ID.

    Args:
        document_id: ID of the document to retrieve.

    Returns:
        Document object if found, else None.
    """
    for doc in self.documents:
        if doc.id == document_id:
            return doc
    return None
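
Example (not from the source): a short sketch of the document accessors, assuming corpus already holds documents; the import paths are inferred from the source locations shown on this page.

from typing import Optional

from crisp_t.model.corpus import Corpus
from crisp_t.model.document import Document


def first_document(corpus: Corpus) -> Optional[Document]:
    # List all IDs, then fetch one document back by its ID
    ids = corpus.get_all_document_ids()
    return corpus.get_document_by_id(ids[0]) if ids else None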

get_relationships()

Get all relationships in the corpus metadata.

Returns: List of relationships, or empty list if none exist.

Source code in src/crisp_t/model/corpus.py, lines 358-365
def get_relationships(self):
    """
    Get all relationships in the corpus metadata.

    Returns:
        List of relationships, or empty list if none exist.
    """
    return self.metadata.get("relationships", [])

get_row_by_index(index)

Get a row from the DataFrame by its index.

Parameters:

index (int): Index of the row to retrieve. Required.

Returns: Row as a pandas Series if index is valid, else None.

Source code in src/crisp_t/model/corpus.py, lines 269-280
def get_row_by_index(self, index: int) -> Optional[pd.Series]:
    """
    Get a row from the DataFrame by its index.

    Args:
        index: Index of the row to retrieve.
    Returns:
        Row as a pandas Series if index is valid, else None.
    """
    if self.df is not None and 0 <= index < len(self.df):
        return self.df.iloc[index]
    return None
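
Example (not from the source): a sketch that chains the DataFrame accessors above; it assumes corpus.df has already been populated (for example, from a CSV loaded via the CLI).

from crisp_t.model.corpus import Corpus  # import path inferred from the source location above


def summarize_dataframe(corpus: Corpus) -> None:
    # Each helper degrades gracefully when corpus.df is None
    print("Columns:", corpus.get_all_df_column_names())
    print("Rows:", corpus.get_row_count())
    first_row = corpus.get_row_by_index(0)
    if first_row is not None:
        print(first_row)
    print(corpus.get_descriptive_statistics())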

get_row_count()

Get the number of rows in the DataFrame.

Returns: Number of rows in the DataFrame, or 0 if the DataFrame is None.

Source code in src/crisp_t/model/corpus.py, lines 258-267
def get_row_count(self):
    """
    Get the number of rows in the DataFrame.

    Returns:
        Number of rows in the DataFrame, or 0 if DataFrame is None.
    """
    if self.df is not None:
        return len(self.df)
    return 0

pretty_print(show='all')

Print the corpus information in a human-readable format.

Parameters:

show: Display option (default 'all'). Can be:
- "all": Show all corpus information
- "documents": Show first 5 documents
- "documents N": Show first N documents (e.g., "documents 10")
- "documents metadata": Show document-specific metadata
- "dataframe": Show DataFrame head
- "dataframe metadata": Show DataFrame metadata columns (metadata_*)
- "dataframe stats": Show DataFrame statistics
- "metadata": Show all corpus metadata
- "metadata KEY": Show specific metadata key (e.g., "metadata pca")
- "stats": Show DataFrame statistics (deprecated, use "dataframe stats")
Source code in src/crisp_t/model/corpus.py, lines 53-215
def pretty_print(self, show="all"):
    """
    Print the corpus information in a human-readable format.

    Args:
        show: Display option. Can be:
            - "all": Show all corpus information
            - "documents": Show first 5 documents
            - "documents N": Show first N documents (e.g., "documents 10")
            - "documents metadata": Show document-specific metadata
            - "dataframe": Show DataFrame head
            - "dataframe metadata": Show DataFrame metadata columns (metadata_*)
            - "dataframe stats": Show DataFrame statistics
            - "metadata": Show all corpus metadata
            - "metadata KEY": Show specific metadata key (e.g., "metadata pca")
            - "stats": Show DataFrame statistics (deprecated, use "dataframe stats")
    """
    # Color codes for terminal output
    BLUE = "\033[94m"
    GREEN = "\033[92m"
    YELLOW = "\033[93m"
    CYAN = "\033[96m"
    MAGENTA = "\033[95m"
    RED = "\033[91m"
    RESET = "\033[0m"
    BOLD = "\033[1m"

    # Parse the show parameter to support subcommands
    parts = show.split(maxsplit=1)
    main_command = parts[0]
    sub_command = parts[1] if len(parts) > 1 else None

    # Print basic corpus info for most commands
    if main_command in ["all", "documents", "dataframe", "metadata"]:
        print(f"{BOLD}{BLUE}Corpus ID:{RESET} {self.id}")
        print(f"{BOLD}{BLUE}Name:{RESET} {self.name}")
        print(f"{BOLD}{BLUE}Description:{RESET} {self.description}")

    # Handle documents command
    if main_command in ["all", "documents"]:
        if sub_command == "metadata":
            # Show document-specific metadata
            print(f"\n{BOLD}{GREEN}=== Document Metadata ==={RESET}")
            if not self.documents:
                print("No documents in corpus")
            else:
                for i, doc in enumerate(self.documents, 1):
                    print(f"\n{CYAN}Document {i}:{RESET}")
                    print(f"  {BOLD}ID:{RESET} {doc.id}")
                    print(f"  {BOLD}Name:{RESET} {doc.name}")
                    if doc.metadata:
                        print(f"  {BOLD}Metadata:{RESET}")
                        for key, value in doc.metadata.items():
                            # Truncate long values
                            val_str = str(value)
                            if len(val_str) > 100:
                                val_str = val_str[:97] + "..."
                            print(f"    {YELLOW}{key}:{RESET} {val_str}")
                    else:
                        print(f"  {BOLD}Metadata:{RESET} (none)")
        else:
            # Determine how many documents to show
            num_docs = 5  # default
            if sub_command:
                try:
                    num_docs = int(sub_command)
                except ValueError:
                    print(f"{RED}Invalid number for documents: {sub_command}. Using default (5).{RESET}")

            print(f"\n{BOLD}{GREEN}=== Documents ==={RESET}")
            print(f"Total documents: {len(self.documents)}")
            print(f"Showing first {min(num_docs, len(self.documents))} document(s):\n")

            for i, doc in enumerate(self.documents[:num_docs], 1):
                print(f"{CYAN}Document {i}:{RESET}")
                print(f"  {BOLD}Name:{RESET} {doc.name}")
                print(f"  {BOLD}ID:{RESET} {doc.id}")
                # Show a snippet of text if available
                if hasattr(doc, 'text') and doc.text:
                    text_snippet = doc.text[:200] + "..." if len(doc.text) > 200 else doc.text
                    print(f"  {BOLD}Text:{RESET} {text_snippet}")
                print()

    # Handle dataframe command
    if main_command in ["all", "dataframe"]:
        if self.df is not None:
            if sub_command == "metadata":
                # Show DataFrame metadata columns (columns starting with metadata_)
                print(f"\n{BOLD}{GREEN}=== DataFrame Metadata Columns ==={RESET}")
                metadata_cols = [col for col in self.df.columns if col.startswith("metadata_")]
                if metadata_cols:
                    print(f"Found {len(metadata_cols)} metadata column(s):")
                    for col in metadata_cols:
                        print(f"  {YELLOW}{col}{RESET}")
                        # Show some statistics for the metadata column
                        print(f"    Non-null values: {self.df[col].notna().sum()}")
                        print(f"    Null values: {self.df[col].isna().sum()}")
                        # Show unique values if not too many
                        unique_count = self.df[col].nunique()
                        if unique_count <= 10:
                            print(f"    Unique values ({unique_count}): {list(self.df[col].unique())}")
                        else:
                            print(f"    Unique values: {unique_count}")
                else:
                    print("No metadata columns found (columns starting with 'metadata_')")
            elif sub_command == "stats":
                # Show DataFrame statistics
                self._print_dataframe_stats()
            else:
                # Show DataFrame head
                print(f"\n{BOLD}{GREEN}=== DataFrame ==={RESET}")
                print(f"Shape: {self.df.shape}")
                print(f"Columns: {list(self.df.columns)}")
                print("\nFirst few rows:")
                print(self.df.head())
        else:
            if main_command == "dataframe":
                print(f"\n{BOLD}{RED}No DataFrame available{RESET}")

    # Handle metadata command
    if main_command in ["all", "metadata"]:
        if sub_command:
            # Show specific metadata key
            print(f"\n{BOLD}{GREEN}=== Metadata: {sub_command} ==={RESET}")
            if sub_command in self.metadata:
                value = self.metadata[sub_command]
                # Format the output based on the type of value
                if isinstance(value, dict):
                    for k, v in value.items():
                        print(f"{YELLOW}{k}:{RESET} {v}")
                elif isinstance(value, list):
                    for i, item in enumerate(value, 1):
                        print(f"{i}. {item}")
                else:
                    print(value)
            else:
                print(f"{RED}Metadata key '{sub_command}' not found{RESET}")
                available_keys = list(self.metadata.keys())
                if available_keys:
                    print(f"Available keys: {', '.join(available_keys)}")
        else:
            # Show all metadata
            print(f"\n{BOLD}{GREEN}=== Corpus Metadata ==={RESET}")
            if not self.metadata:
                print("No metadata available")
            else:
                for key, value in self.metadata.items():
                    print(f"\n{MAGENTA}{key}:{RESET}")
                    # Truncate long values
                    val_str = str(value)
                    if len(val_str) > 500:
                        val_str = val_str[:497] + "..."
                    print(f"  {val_str}")

    # Handle stats command (deprecated, redirect to dataframe stats)
    if main_command == "stats":
        print(f"{YELLOW}Note: 'stats' is deprecated. Use 'dataframe stats' instead.{RESET}")
        if self.df is not None:
            self._print_dataframe_stats()
        else:
            print(f"{RED}No DataFrame available{RESET}")

    print(f"\n{BOLD}Display completed for '{show}'{RESET}")

remove_document_by_id(document_id)

Remove a document from the corpus by its ID.

Parameters:

document_id (str): ID of the document to remove. Required.
Source code in src/crisp_t/model/corpus.py, lines 315-324
def remove_document_by_id(self, document_id: str):
    """
    Remove a document from the corpus by its ID.

    Args:
        document_id: ID of the document to remove.
    """
    self.documents = [
        doc for doc in self.documents if doc.id != document_id
    ]

update_metadata(key, value)

Update the metadata of the corpus.

Parameters:

key (str): Metadata key to update. Required.
value (Any): New value for the metadata key. Required.
Source code in src/crisp_t/model/corpus.py, lines 326-334
def update_metadata(self, key: str, value: Any):
    """
    Update the metadata of the corpus.

    Args:
        key: Metadata key to update.
        value: New value for the metadata key.
    """
    self.metadata[key] = value
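
Example (not from the source): a minimal sketch of update_metadata; the keys and values are illustrative only.

from crisp_t.model.corpus import Corpus  # import path inferred from the source location above


def tag_corpus(corpus: Corpus) -> None:
    corpus.update_metadata("topics", ["workload", "fatigue", "staffing"])
    corpus.update_metadata("source", "interview transcripts")
    print(corpus.metadata["topics"])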

Copyright (C) 2025 Bell Eapen

This file is part of crisp-t.

crisp-t is free software: you can redistribute it and/or modify it under the terms of the GNU General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your option) any later version.

crisp-t is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for more details.

You should have received a copy of the GNU General Public License along with crisp-t. If not, see https://www.gnu.org/licenses/.

Document

Bases: BaseModel

Document model for storing text and metadata.

Source code in src/crisp_t/model/document.py, lines 24-48
class Document(BaseModel):
    """
    Document model for storing text and metadata.
    """
    id: str = Field(..., description="Unique identifier for the document.")
    name: Optional[str] = Field(None, description="Name of the document.")
    description: Optional[str] = Field(None, description="Description of the document.")
    score: float = Field(0.0, description="Score associated with the document.")
    text: str = Field(..., description="The text content of the document.")
    metadata: dict = Field(
        default_factory=dict, description="Metadata associated with the document."
    )

    def pretty_print(self):
        """
        Print the document information in a human-readable format.
        """
        print(f"Document ID: {self.id}")
        print(f"Name: {self.name}")
        print(f"Description: {self.description}")
        print(f"Score: {self.score}")
        print(f"Text: {self.text[:100]}...")  # Print first 100 characters of text
        print(f"Metadata: {self.metadata}")
        print(f"Length of Text: {len(self.text)} characters")
        print(f"Number of Metadata Entries: {len(self.metadata)}")

pretty_print()

Print the document information in a human-readable format.

Source code in src/crisp_t/model/document.py, lines 37-48
def pretty_print(self):
    """
    Print the document information in a human-readable format.
    """
    print(f"Document ID: {self.id}")
    print(f"Name: {self.name}")
    print(f"Description: {self.description}")
    print(f"Score: {self.score}")
    print(f"Text: {self.text[:100]}...")  # Print first 100 characters of text
    print(f"Metadata: {self.metadata}")
    print(f"Length of Text: {len(self.text)} characters")
    print(f"Number of Metadata Entries: {len(self.metadata)}")

Csv

Source code in src/crisp_t/csv.py, lines 17-471
class Csv:

    def __init__(
        self,
        corpus: Optional[Corpus] = None,
        comma_separated_text_columns: str = "",
        comma_separated_ignore_columns: str = "",
        id_column: str = "id",
    ):
        """
        Initialize the Csv object.
        """
        self._corpus = corpus
        if self._corpus is None:
            self._df = pd.DataFrame()
            logger.info("No corpus provided. Creating an empty DataFrame.")
        else:
            self._df = self._corpus.df
            if self._df is None:
                logger.info("No DataFrame found in the corpus. Creating a new one.")
                self._df = pd.DataFrame()
        self._df_original = self._df.copy()
        self._comma_separated_text_columns = comma_separated_text_columns
        self._comma_separated_ignore_columns = comma_separated_ignore_columns
        self._id_column = id_column
        self._X = None
        self._y = None
        self._X_original = None
        self._y_original = None

    @property
    def corpus(self) -> Optional[Corpus]:
        if self._corpus is not None and self._df is not None:
            self._corpus.df = self._df
        return self._corpus

    @property
    def df(self) -> pd.DataFrame:
        if self._df is None:
            return pd.DataFrame()
        return self._df

    @property
    def comma_separated_text_columns(self) -> str:
        return self._comma_separated_text_columns

    @property
    def comma_separated_ignore_columns(self) -> str:
        return self._comma_separated_ignore_columns

    @comma_separated_ignore_columns.setter
    def comma_separated_ignore_columns(self, value: str) -> None:
        self._comma_separated_ignore_columns = value
        logger.info("Comma-separated ignore columns set successfully.")
        logger.debug(
            f"Comma-separated ignore columns: {self._comma_separated_ignore_columns}"
        )
        self._process_columns()

    @property
    def id_column(self) -> str:
        return self._id_column

    @corpus.setter
    def corpus(self, value: Corpus) -> None:
        self._corpus = value
        if self._corpus is not None:
            self._df = self._corpus.df
            if self._df is None:
                logger.info("No DataFrame found in the corpus. Creating a new one.")
                self._df = pd.DataFrame()
            self._df_original = self._df.copy()
            logger.info("Corpus set successfully.")
            logger.debug(f"DataFrame content: {self._df.head()}")
            logger.debug(f"DataFrame shape: {self._df.shape}")
            logger.debug(f"DataFrame columns: {self._df.columns.tolist()}")
        else:
            logger.error("Failed to set corpus. Corpus is None.")

    @df.setter
    def df(self, value: pd.DataFrame) -> None:
        self._df = value
        logger.info("DataFrame set successfully.")
        logger.debug(f"DataFrame content: {self._df.head()}")
        logger.debug(f"DataFrame shape: {self._df.shape}")
        logger.debug(f"DataFrame columns: {self._df.columns.tolist()}")

    @comma_separated_text_columns.setter
    def comma_separated_text_columns(self, value: str) -> None:
        self._comma_separated_text_columns = value
        logger.info("Comma-separated text columns set successfully.")
        logger.debug(
            f"Comma-separated text columns: {self._comma_separated_text_columns}"
        )
        self._process_columns()

    @id_column.setter
    def id_column(self, value: str) -> None:
        self._id_column = value
        # Add id column to the list of ignored columns
        ignore_cols = [
            col
            for col in self._comma_separated_ignore_columns.split(",")
            if col.strip()
        ]
        if value not in ignore_cols:
            ignore_cols.append(value)
            self._comma_separated_ignore_columns = ",".join(ignore_cols)
            logger.debug(
                f"ID column '{value}' added to ignore columns: {self._comma_separated_ignore_columns}"
            )
        logger.info("ID column set successfully.")
        logger.debug(f"ID column: {self._id_column}")

    # TODO remove @deprecated
    #! Do not use
    def read_csv(self, file_path: str):
        """
        Read a CSV file and create a DataFrame.
        """
        try:
            self._df = pd.read_csv(file_path)
            logger.info(f"CSV file {file_path} read successfully.")
            logger.debug(f"DataFrame content: {self._df.head()}")
            logger.debug(f"DataFrame shape: {self._df.shape}")
            logger.debug(f"DataFrame columns: {self._df.columns.tolist()}")
        except Exception as e:
            logger.error(f"Error reading CSV file: {e}")
            raise
        return self._process_columns()

    def _process_columns(self):
        # ignore comma-separated ignore columns
        if self._comma_separated_ignore_columns:
            ignore_columns = [
                col.strip()
                for col in self._comma_separated_ignore_columns.split(",")
                if col.strip()
            ]
            self._df.drop(columns=ignore_columns, inplace=True, errors="ignore")
            logger.info(
                f"Ignored columns: {ignore_columns}. Updated DataFrame shape: {self._df.shape}"
            )
            logger.debug(f"DataFrame content after dropping columns: {self._df.head()}")
        # ignore comma-separated text columns
        if self._comma_separated_text_columns:
            text_columns = [
                col.strip()
                for col in self._comma_separated_text_columns.split(",")
                if col.strip()
            ]
            for col in text_columns:
                if col in self._df.columns:
                    self._df[col] = self._df[col].astype(str)
                    logger.info(f"Column {col} converted to string.")
                    logger.debug(f"Column {col} content: {self._df[col].head()}")
                else:
                    logger.warning(f"Column {col} not found in DataFrame.")
        # ignore all columns with names starting with "metadata_"
        self._df = self._df.loc[:, ~self._df.columns.str.startswith("metadata_")]
        return self._df

    def write_csv(self, file_path: str, index: bool = False) -> None:
        if self._df is not None:
            self._df.to_csv(file_path, index=index)
            logger.info(f"DataFrame written to {file_path}")
            logger.debug(f"DataFrame content: {self._df.head()}")
            logger.debug(f"Index: {index}")
        else:
            logger.error("DataFrame is None. Cannot write to CSV.")

    def mark_missing(self):
        """Mark missing values in the DataFrame.
        Missing values are considered as empty strings and are replaced with NaN.
        Rows with NaN values are then dropped from the DataFrame.
        """
        if self._df is not None:
            self._df.replace("", np.nan, inplace=True)
            self._df.dropna(inplace=True)
        else:
            logger.error("DataFrame is None. Cannot mark missing values.")

    def mark_duplicates(self):
        """Mark duplicate rows in the DataFrame.
        Duplicate rows are identified and dropped from the DataFrame.
        """
        if self._df is not None:
            self._df.drop_duplicates(inplace=True)
        else:
            logger.error("DataFrame is None. Cannot mark duplicates.")

    def restore_df(self):
        self._df = self._df_original.copy()

    def get_shape(self):
        if self._df is not None:
            return self._df.shape
        else:
            logger.error("DataFrame is None. Cannot get shape.")
            return None

    def get_columns(self):
        """Get the list of columns in the DataFrame."""
        if self._df is not None:
            return self._df.columns.tolist()
        else:
            logger.error("DataFrame is None. Cannot get columns.")
            return []

    def get_column_types(self):
        """Get the data types of columns in the DataFrame."""
        if self._df is not None:
            return self._df.dtypes.to_dict()
        else:
            logger.error("DataFrame is None. Cannot get column types.")
            return {}

    def get_column_values(self, column_name: str):
        """Get the unique values in a column of the DataFrame."""
        if self._df is not None and column_name in self._df.columns:
            return self._df[column_name].tolist()
        else:
            logger.error(
                f"Column {column_name} not found in DataFrame or DataFrame is None."
            )
            return None

    def retain_numeric_columns_only(self):
        """Retain only numeric columns in the DataFrame."""
        if self._df is not None:
            self._df = self._df.select_dtypes(include=[np.number])
            logger.info("DataFrame filtered to numeric columns only.")
        else:
            logger.error("DataFrame is None. Cannot filter to numeric columns.")

    def comma_separated_include_columns(self, include_cols: str = ""):
        """Retain only specified columns in the DataFrame."""
        if include_cols == "":
            return
        if self._df is not None:
            cols = [
                col.strip()
                for col in include_cols.split(",")
                if col.strip() and col in self._df.columns
            ]
            self._df = self._df[cols]
            logger.info(f"DataFrame filtered to include columns: {cols}")
        else:
            logger.error("DataFrame is None. Cannot filter to include columns.")

    def read_xy(self, y: str):
        """
        Read X and y variables from the DataFrame.
        """
        if self._df is None:
            logger.error("DataFrame is None. Cannot read X and y.")
            return None, None
        # Split into X and y
        if y == "":
            self._y = None
        else:
            self._y = self._df[y]
        if y != "":
            self._X = self._df.drop(columns=[y])
        else:
            self._X = self._df.copy()
        logger.info(f"X and y variables set. X shape: {self._X.shape}")
        return self._X, self._y

    def drop_na(self):
        """Drop rows with any NA values from the DataFrame."""
        if self._df is not None:
            self._df.dropna(inplace=True)
            logger.info("Missing values dropped from DataFrame.")
        else:
            logger.error("DataFrame is None. Cannot drop missing values.")

    def oversample(self, mcp: bool = False):
        self._X_original = self._X
        self._y_original = self._y
        try:
            from imblearn.over_sampling import RandomOverSampler

            ros = RandomOverSampler(random_state=0)
        except ImportError:
            logger.info(
                "ML dependencies are not installed. Please install them with `pip install crisp-t[ml]` to use ML features."
            )
            return

        result = ros.fit_resample(self._X, self._y)
        if len(result) == 2:
            X, y = result
        elif len(result) == 3:
            X, y, _ = result
        else:
            logger.error("Unexpected number of values returned from fit_resample.")
            return
        self._X = X
        self._y = y
        if mcp:
            return f"Oversampling completed. New X shape: {self._X.shape}"
        return X, y

    def restore_oversample(self, mcp: bool = False):
        self._X = self._X_original
        self._y = self._y_original
        if mcp:
            return f"Oversampling restored. X shape: {self._X.shape}, y shape: {self._y.shape}"  # type: ignore

    def prepare_data(self, y: str, oversample=False, one_hot_encode_all=False):
        self.mark_missing()
        if oversample:
            self.oversample()
        self.one_hot_encode_strings_in_df()
        if one_hot_encode_all:
            self.one_hot_encode_all_columns()
        return self.read_xy(y)

    def bin_a_column(self, column_name: str, bins: int = 2):
        """Bin a numeric column into specified number of bins."""
        if self._df is not None and column_name in self._df.columns:
            if pd.api.types.is_numeric_dtype(self._df[column_name]):
                self._df[column_name] = pd.cut(
                    self._df[column_name], bins=bins, labels=False
                )
                logger.info(f"Column {column_name} binned into {bins} bins.")
                return "I have binned the column. Please proceed."
            else:
                logger.warning(f"Column {column_name} is not numeric. Cannot bin.")
        else:
            logger.warning(
                f"Column {column_name} not found in DataFrame or DataFrame is None."
            )
        return "I cannot bin the column. Please check the logs for more information."

    def one_hot_encode_column(self, column_name: str):
        """One-hot encode a specific column in the DataFrame.
        This method converts a categorical column into one-hot encoded columns.
        Used when # ValueError: could not convert string to float.
        """
        if self._df is not None and column_name in self._df.columns:
            if pd.api.types.is_object_dtype(self._df[column_name]):
                self._df = pd.get_dummies(
                    self._df, columns=[column_name], drop_first=True
                )
                logger.info(f"One-hot encoding applied to column {column_name}.")
                return "I have one-hot encoded the column. Please proceed."
            else:
                logger.warning(f"Column {column_name} is not of object type.")
        else:
            logger.error(
                f"Column {column_name} not found in DataFrame or DataFrame is None."
            )
        return "I cannot one-hot encode the column. Please check the logs for more information."

    def one_hot_encode_strings_in_df(self, n=10, filter_high_cardinality=False):
        """One-hot encode string (object) columns in the DataFrame.
        This method converts categorical string columns into one-hot encoded columns.
        Columns with more than n unique values can be optionally filtered out.
        Used when # ValueError: could not convert string to float.
        """
        if self._df is not None:
            categorical_cols = self._df.select_dtypes(
                include=["object"]
            ).columns.tolist()
            # Remove categorical columns with more than n unique values
            if filter_high_cardinality:
                categorical_cols = [
                    col for col in categorical_cols if self._df[col].nunique() <= n
                ]
            if categorical_cols:
                self._df = pd.get_dummies(
                    self._df, columns=categorical_cols, drop_first=True
                )
                logger.info("One-hot encoding applied to string columns.")
            else:
                logger.info("No string (object) columns found for one-hot encoding.")
        else:
            logger.error("DataFrame is None. Cannot apply one-hot encoding.")

    def one_hot_encode_all_columns(self):
        """One-hot encode all columns in the DataFrame.
        This method converts all values in the DataFrame to boolean values.
        Used for apriori algorithm which requires boolean values.
        """
        if self._df is not None:

            def to_one_hot(x):
                if x in [1, True]:
                    return True
                elif x in [0, False]:
                    return False
                else:
                    # logger.warning(
                    #     f"Unexpected value '{x}' encountered during one-hot encoding; mapping to 1."
                    # )
                    return True

            self._df = self._df.applymap(to_one_hot)  # type: ignore

    def filter_rows_by_column_value(
        self, column_name: str, value, mcp: bool = False
    ):
        """Select rows from the DataFrame where the specified column matches the given value.
        Additionally, filter self._corpus.documents by id_column if present in DataFrame.
        """
        if self._df is not None and column_name in self._df.columns:
            selected_df = self._df[self._df[column_name] == value]
            if selected_df.empty:
                # try int search
                try:
                    selected_df = self._df[self._df[column_name] == int(value)]
                except (ValueError, TypeError):
                    logger.warning(
                        f"Could not convert value '{value}' to int for column '{column_name}'."
                    )
            logger.info(
                f"Selected {selected_df.shape[0]} rows where {column_name} == {value}."
            )
            self._df = selected_df

            # Check for id_column in DataFrame
            if (
                self._corpus is not None
                and hasattr(self._corpus, "df")
                and self._id_column in self._corpus.df.columns
            ):
                logger.info(f"id_column '{self._id_column}' exists in DataFrame.")
                valid_ids = set(self._corpus.df[self._id_column].tolist())
                if (
                    hasattr(self._corpus, "documents")
                    and self._corpus.documents is not None
                ):
                    filtered_docs = [
                        doc
                        for doc in self._corpus.documents
                        if getattr(doc, self._id_column, None) in valid_ids
                    ]
                    self._corpus.documents = filtered_docs
            else:
                logger.warning(f"id_column '{self._id_column}' does not exist in DataFrame.")

            if mcp:
                return f"Selected {selected_df.shape[0]} rows where {column_name} == {value}."
        else:
            logger.warning(
                f"Column {column_name} not found in DataFrame or DataFrame is None."
            )
            if mcp:
                return (
                    f"Column {column_name} not found in DataFrame or DataFrame is None."
                )
            return pd.DataFrame()
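
Example (not from the source): a short end-to-end sketch of the Csv helper on a plain DataFrame with no corpus attached. The column names and the import path (inferred from src/crisp_t/csv.py) are illustrative.

import pandas as pd

from crisp_t.csv import Csv

csv = Csv()
csv.df = pd.DataFrame(
    {
        "id": [1, 2, 3, 4],
        "age": [34, 51, 42, 29],
        "gender": ["f", "m", "f", "m"],
        "outcome": [0, 1, 1, 0],
    }
)
csv.comma_separated_ignore_columns = "id"  # the setter drops ignored columns immediately

# Clean, one-hot encode string columns, and split into X and y
X, y = csv.prepare_data(y="outcome")
print(X.shape, y.shape)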

__init__(corpus=None, comma_separated_text_columns='', comma_separated_ignore_columns='', id_column='id')

Initialize the Csv object.

Source code in src/crisp_t/csv.py, lines 19-46
def __init__(
    self,
    corpus: Optional[Corpus] = None,
    comma_separated_text_columns: str = "",
    comma_separated_ignore_columns: str = "",
    id_column: str = "id",
):
    """
    Initialize the Csv object.
    """
    self._corpus = corpus
    if self._corpus is None:
        self._df = pd.DataFrame()
        logger.info("No corpus provided. Creating an empty DataFrame.")
    else:
        self._df = self._corpus.df
        if self._df is None:
            logger.info("No DataFrame found in the corpus. Creating a new one.")
            self._df = pd.DataFrame()
    self._df_original = self._df.copy()
    self._comma_separated_text_columns = comma_separated_text_columns
    self._comma_separated_ignore_columns = comma_separated_ignore_columns
    self._id_column = id_column
    self._X = None
    self._y = None
    self._X_original = None
    self._y_original = None

bin_a_column(column_name, bins=2)

Bin a numeric column into specified number of bins.

Source code in src/crisp_t/csv.py, lines 337-352
def bin_a_column(self, column_name: str, bins: int = 2):
    """Bin a numeric column into specified number of bins."""
    if self._df is not None and column_name in self._df.columns:
        if pd.api.types.is_numeric_dtype(self._df[column_name]):
            self._df[column_name] = pd.cut(
                self._df[column_name], bins=bins, labels=False
            )
            logger.info(f"Column {column_name} binned into {bins} bins.")
            return "I have binned the column. Please proceed."
        else:
            logger.warning(f"Column {column_name} is not numeric. Cannot bin.")
    else:
        logger.warning(
            f"Column {column_name} not found in DataFrame or DataFrame is None."
        )
    return "I cannot bin the column. Please check the logs for more information."

comma_separated_include_columns(include_cols='')

Retain only specified columns in the DataFrame.

Source code in src/crisp_t/csv.py, lines 253-266
def comma_separated_include_columns(self, include_cols: str = ""):
    """Retain only specified columns in the DataFrame."""
    if include_cols == "":
        return
    if self._df is not None:
        cols = [
            col.strip()
            for col in include_cols.split(",")
            if col.strip() and col in self._df.columns
        ]
        self._df = self._df[cols]
        logger.info(f"DataFrame filtered to include columns: {cols}")
    else:
        logger.error("DataFrame is None. Cannot filter to include columns.")

drop_na()

Drop rows with any NA values from the DataFrame.

Source code in src/crisp_t/csv.py, lines 287-293
def drop_na(self):
    """Drop rows with any NA values from the DataFrame."""
    if self._df is not None:
        self._df.dropna(inplace=True)
        logger.info("Missing values dropped from DataFrame.")
    else:
        logger.error("DataFrame is None. Cannot drop missing values.")

filter_rows_by_column_value(column_name, value, mcp=False)

Select rows from the DataFrame where the specified column matches the given value. Additionally, filter self._corpus.documents by id_column if present in DataFrame.

Source code in src/crisp_t/csv.py, lines 419-471
def filter_rows_by_column_value(
    self, column_name: str, value, mcp: bool = False
):
    """Select rows from the DataFrame where the specified column matches the given value.
    Additionally, filter self._corpus.documents by id_column if present in DataFrame.
    """
    if self._df is not None and column_name in self._df.columns:
        selected_df = self._df[self._df[column_name] == value]
        if selected_df.empty:
            # try int search
            try:
                selected_df = self._df[self._df[column_name] == int(value)]
            except (ValueError, TypeError):
                logger.warning(
                    f"Could not convert value '{value}' to int for column '{column_name}'."
                )
        logger.info(
            f"Selected {selected_df.shape[0]} rows where {column_name} == {value}."
        )
        self._df = selected_df

        # Check for id_column in DataFrame
        if (
            self._corpus is not None
            and hasattr(self._corpus, "df")
            and self._id_column in self._corpus.df.columns
        ):
            logger.info(f"id_column '{self._id_column}' exists in DataFrame.")
            valid_ids = set(self._corpus.df[self._id_column].tolist())
            if (
                hasattr(self._corpus, "documents")
                and self._corpus.documents is not None
            ):
                filtered_docs = [
                    doc
                    for doc in self._corpus.documents
                    if getattr(doc, self._id_column, None) in valid_ids
                ]
                self._corpus.documents = filtered_docs
        else:
            logger.warning(f"id_column '{self._id_column}' does not exist in DataFrame.")

        if mcp:
            return f"Selected {selected_df.shape[0]} rows where {column_name} == {value}."
    else:
        logger.warning(
            f"Column {column_name} not found in DataFrame or DataFrame is None."
        )
        if mcp:
            return (
                f"Column {column_name} not found in DataFrame or DataFrame is None."
            )
        return pd.DataFrame()
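
Example (not from the source): filtering rows on a column value, with illustrative data. Without an attached corpus the document-filtering branch is skipped and only the DataFrame is reduced.

import pandas as pd

from crisp_t.csv import Csv  # import path inferred from the source location above

csv = Csv()
csv.df = pd.DataFrame({"site": ["a", "a", "b"], "score": [3, 5, 4]})
csv.filter_rows_by_column_value("site", "a")
print(csv.get_shape())  # (2, 2): only the rows where site == "a" remain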

get_column_types()

Get the data types of columns in the DataFrame.

Source code in src/crisp_t/csv.py, lines 227-233
def get_column_types(self):
    """Get the data types of columns in the DataFrame."""
    if self._df is not None:
        return self._df.dtypes.to_dict()
    else:
        logger.error("DataFrame is None. Cannot get column types.")
        return {}

get_column_values(column_name)

Get the unique values in a column of the DataFrame.

Source code in src/crisp_t/csv.py, lines 235-243
def get_column_values(self, column_name: str):
    """Get the unique values in a column of the DataFrame."""
    if self._df is not None and column_name in self._df.columns:
        return self._df[column_name].tolist()
    else:
        logger.error(
            f"Column {column_name} not found in DataFrame or DataFrame is None."
        )
        return None

get_columns()

Get the list of columns in the DataFrame.

Source code in src/crisp_t/csv.py, lines 219-225
def get_columns(self):
    """Get the list of columns in the DataFrame."""
    if self._df is not None:
        return self._df.columns.tolist()
    else:
        logger.error("DataFrame is None. Cannot get columns.")
        return []

mark_duplicates()

Mark duplicate rows in the DataFrame. Duplicate rows are identified and dropped from the DataFrame.

Source code in src/crisp_t/csv.py, lines 200-207
def mark_duplicates(self):
    """Mark duplicate rows in the DataFrame.
    Duplicate rows are identified and dropped from the DataFrame.
    """
    if self._df is not None:
        self._df.drop_duplicates(inplace=True)
    else:
        logger.error("DataFrame is None. Cannot mark duplicates.")

mark_missing()

Mark missing values in the DataFrame. Empty strings are treated as missing and replaced with NaN, and rows containing NaN values are then dropped from the DataFrame.

Source code in src/crisp_t/csv.py, lines 189-198
def mark_missing(self):
    """Mark missing values in the DataFrame.
    Missing values are considered as empty strings and are replaced with NaN.
    Rows with NaN values are then dropped from the DataFrame.
    """
    if self._df is not None:
        self._df.replace("", np.nan, inplace=True)
        self._df.dropna(inplace=True)
    else:
        logger.error("DataFrame is None. Cannot mark missing values.")

one_hot_encode_all_columns()

One-hot encode all columns in the DataFrame. This method converts all values in the DataFrame to boolean values, as required by the apriori algorithm.

Source code in src/crisp_t/csv.py, lines 399-417
def one_hot_encode_all_columns(self):
    """One-hot encode all columns in the DataFrame.
    This method converts all values in the DataFrame to boolean values.
    Used for apriori algorithm which requires boolean values.
    """
    if self._df is not None:

        def to_one_hot(x):
            if x in [1, True]:
                return True
            elif x in [0, False]:
                return False
            else:
                # logger.warning(
                #     f"Unexpected value '{x}' encountered during one-hot encoding; mapping to 1."
                # )
                return True

        self._df = self._df.applymap(to_one_hot)  # type: ignore

one_hot_encode_column(column_name)

One-hot encode a specific column in the DataFrame. This method converts a categorical column into one-hot encoded columns. Useful when model training fails with "ValueError: could not convert string to float".

Source code in src/crisp_t/csv.py, lines 354-372
def one_hot_encode_column(self, column_name: str):
    """One-hot encode a specific column in the DataFrame.
    This method converts a categorical column into one-hot encoded columns.
    Used when # ValueError: could not convert string to float.
    """
    if self._df is not None and column_name in self._df.columns:
        if pd.api.types.is_object_dtype(self._df[column_name]):
            self._df = pd.get_dummies(
                self._df, columns=[column_name], drop_first=True
            )
            logger.info(f"One-hot encoding applied to column {column_name}.")
            return "I have one-hot encoded the column. Please proceed."
        else:
            logger.warning(f"Column {column_name} is not of object type.")
    else:
        logger.error(
            f"Column {column_name} not found in DataFrame or DataFrame is None."
        )
    return "I cannot one-hot encode the column. Please check the logs for more information."

one_hot_encode_strings_in_df(n=10, filter_high_cardinality=False)

One-hot encode string (object) columns in the DataFrame. This method converts categorical string columns into one-hot encoded columns. Columns with more than n unique values can optionally be filtered out. Useful when model training fails with "ValueError: could not convert string to float".

Source code in src/crisp_t/csv.py, lines 374-397
def one_hot_encode_strings_in_df(self, n=10, filter_high_cardinality=False):
    """One-hot encode string (object) columns in the DataFrame.
    This method converts categorical string columns into one-hot encoded columns.
    Columns with more than n unique values can be optionally filtered out.
    Used when # ValueError: could not convert string to float.
    """
    if self._df is not None:
        categorical_cols = self._df.select_dtypes(
            include=["object"]
        ).columns.tolist()
        # Remove categorical columns with more than n unique values
        if filter_high_cardinality:
            categorical_cols = [
                col for col in categorical_cols if self._df[col].nunique() <= n
            ]
        if categorical_cols:
            self._df = pd.get_dummies(
                self._df, columns=categorical_cols, drop_first=True
            )
            logger.info("One-hot encoding applied to string columns.")
        else:
            logger.info("No string (object) columns found for one-hot encoding.")
    else:
        logger.error("DataFrame is None. Cannot apply one-hot encoding.")

read_csv(file_path)

Read a CSV file and create a DataFrame.

Source code in src/crisp_t/csv.py, lines 134-147
def read_csv(self, file_path: str):
    """
    Read a CSV file and create a DataFrame.
    """
    try:
        self._df = pd.read_csv(file_path)
        logger.info(f"CSV file {file_path} read successfully.")
        logger.debug(f"DataFrame content: {self._df.head()}")
        logger.debug(f"DataFrame shape: {self._df.shape}")
        logger.debug(f"DataFrame columns: {self._df.columns.tolist()}")
    except Exception as e:
        logger.error(f"Error reading CSV file: {e}")
        raise
    return self._process_columns()

read_xy(y)

Read X and y variables from the DataFrame.

Source code in src/crisp_t/csv.py, lines 268-285
def read_xy(self, y: str):
    """
    Read X and y variables from the DataFrame.
    """
    if self._df is None:
        logger.error("DataFrame is None. Cannot read X and y.")
        return None, None
    # Split into X and y
    if y == "":
        self._y = None
    else:
        self._y = self._df[y]
    if y != "":
        self._X = self._df.drop(columns=[y])
    else:
        self._X = self._df.copy()
    logger.info(f"X and y variables set. X shape: {self._X.shape}")
    return self._X, self._y
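
The split above is ordinary pandas column selection; a minimal sketch with made-up data:

import pandas as pd

df = pd.DataFrame({"age": [30, 41], "income": [50, 60], "outcome": [0, 1]})

y_name = "outcome"  # pass "" to skip the target, as read_xy does
y = df[y_name] if y_name else None
X = df.drop(columns=[y_name]) if y_name else df.copy()
print(X.columns.tolist(), list(y))  # ['age', 'income'] [0, 1]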

retain_numeric_columns_only()

Retain only numeric columns in the DataFrame.

Source code in src/crisp_t/csv.py, lines 245-251
def retain_numeric_columns_only(self):
    """Retain only numeric columns in the DataFrame."""
    if self._df is not None:
        self._df = self._df.select_dtypes(include=[np.number])
        logger.info("DataFrame filtered to numeric columns only.")
    else:
        logger.error("DataFrame is None. Cannot filter to numeric columns.")

Copyright (C) 2025 Bell Eapen

This file is part of crisp-t.

crisp-t is free software: you can redistribute it and/or modify it under the terms of the GNU General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your option) any later version.

crisp-t is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for more details.

You should have received a copy of the GNU General Public License along with crisp-t. If not, see https://www.gnu.org/licenses/.

Text

Source code in src/crisp_t/text.py, lines 49-627
class Text:

    def __init__(
        self, corpus: Optional[Corpus] = None, lang="en_core_web_sm", max_length=1100000
    ):
        self._corpus = corpus
        self._lang = lang
        self._spacy_manager = SpacyManager(self._lang)
        self._max_length = max_length
        self._initial_document_count = len(self._corpus.documents) if corpus else 0  # type: ignore

        self._spacy_doc = None
        self._lemma = {}
        self._pos = {}
        self._pos_ = {}
        self._word = {}
        self._sentiment = {}
        self._tag = {}
        self._dep = {}
        self._prob = {}
        self._idx = {}

    @property
    def corpus(self):
        """
        Get the corpus.
        """
        if self._corpus is None:
            raise ValueError("Corpus is not set")
        return self._corpus

    @property
    def max_length(self):
        """
        Get the maximum length of the corpus.
        """
        return self._max_length

    @property
    def lang(self):
        """
        Get the language of the corpus.
        """
        return self._lang

    @property
    def initial_document_count(self):
        """
        Get the initial document count.
        """
        return self._initial_document_count

    @corpus.setter
    def corpus(self, corpus: Corpus):
        """
        Set the corpus.
        """
        if not isinstance(corpus, Corpus):
            raise ValueError("Corpus must be of type Corpus")
        self._corpus = corpus
        spacy_doc, results = self.process_tokens(self._corpus.id if self._corpus else None)
        self._spacy_doc = spacy_doc
        self._lemma = results["lemma"]
        self._pos = results["pos"]
        self._pos_ = results["pos_"]
        self._word = results["word"]
        self._sentiment = results["sentiment"]
        self._tag = results["tag"]
        self._dep = results["dep"]
        self._prob = results["prob"]
        self._idx = results["idx"]

    @max_length.setter
    def max_length(self, max_length: int):
        """
        Set the maximum length of the corpus.
        """
        if not isinstance(max_length, int):
            raise ValueError("max_length must be an integer")
        self._max_length = max_length
        if self._spacy_doc is not None:
            self._spacy_doc.max_length = max_length

    @lang.setter
    def lang(self, lang: str):
        """
        Set the language of the corpus.
        """
        if not isinstance(lang, str):
            raise ValueError("lang must be a string")
        self._lang = lang
        spacy_doc, results = self.process_tokens(self._corpus.id if self._corpus else None)
        self._spacy_doc = spacy_doc
        self._lemma = results["lemma"]
        self._pos = results["pos"]
        self._pos_ = results["pos_"]
        self._word = results["word"]
        self._sentiment = results["sentiment"]
        self._tag = results["tag"]
        self._dep = results["dep"]
        self._prob = results["prob"]
        self._idx = results["idx"]

    def make_spacy_doc(self):
        if self._corpus is None:
            raise ValueError("Corpus is not set")
        # Use list and join for efficient string concatenation instead of +=
        text_parts = []
        for document in tqdm(
            self._corpus.documents,
            desc="Processing documents",
            disable=len(self._corpus.documents) < 10,
        ):
            text_parts.append(self.process_text(document.text))
        text = " \n".join(text_parts)
        nlp = self._spacy_manager.get_model()
        nlp.max_length = self._max_length
        if len(text) > self._max_length:
            logger.warning(
                f"Text length {len(text)} exceeds max_length {self._max_length}."
            )
            text_chunks = [
                text[i : i + self._max_length]
                for i in range(0, len(text), self._max_length)
            ]
            spacy_docs = []
            for chunk in tqdm(
                text_chunks, desc="Processing text as chunks of max_length"
            ):
                spacy_doc = nlp(chunk)
                spacy_docs.append(spacy_doc)
            self._spacy_doc = spacy_docs[0]
            for doc in tqdm(spacy_docs[1:], desc="Merging spacy docs"):
                self._spacy_doc = Doc.from_docs([self._spacy_doc, doc])  # type: ignore
        else:
            self._spacy_doc = nlp(text)
        return self._spacy_doc

    # @lru_cache(maxsize=3)
    def make_each_document_into_spacy_doc(self, id="corpus"):
        if self._corpus is None:
            raise ValueError("Corpus is not set")

        # ! if cached file exists, load it
        cache_dir = Path("cache")
        cache_file = cache_dir / f"spacy_docs_{id}.pkl"
        if cache_file.exists():
            with open(cache_file, "rb") as f:
                spacy_docs, ids = pickle.load(f)
            logger.info("Loaded cached spacy docs and ids.")
            return spacy_docs, ids

        spacy_docs = []
        ids = []
        # Load SpaCy model once outside the loop for efficiency
        nlp = self._spacy_manager.get_model()
        nlp.max_length = self._max_length
        for document in tqdm(
            self._corpus.documents,
            desc="Creating spacy docs",
            disable=len(self._corpus.documents) < 10,
        ):
            text = self.process_text(document.text)
            spacy_doc = nlp(text)
            spacy_docs.append(spacy_doc)
            ids.append(document.id)

        # ! dump spacy_docs, ids to a file for caching with the corpus id
        cache_dir = Path("cache")
        cache_dir.mkdir(exist_ok=True)
        cache_file = cache_dir / f"spacy_docs_{id}.pkl"
        with open(cache_file, "wb") as f:
            pickle.dump((spacy_docs, ids), f)
        return spacy_docs, ids

    def process_text(self, text: str) -> str:
        """
        Process the text by removing unwanted characters and normalizing it.
        """
        # Remove unwanted characters
        text = preprocessing.replace.urls(text)
        text = preprocessing.replace.emails(text)
        text = preprocessing.replace.phone_numbers(text)
        text = preprocessing.replace.currency_symbols(text)
        text = preprocessing.replace.hashtags(text)
        text = preprocessing.replace.numbers(text)

        # lowercase the text
        text = text.lower()
        return text

    # @lru_cache(maxsize=3)
    def process_tokens(self, id="corpus"):
        """
        Process tokens in the spacy document and extract relevant information.
        """

        # ! if cached file exists, load it
        cache_dir = Path("cache")
        cache_file = cache_dir / f"spacy_doc_{id}.pkl"
        if cache_file.exists():
            with open(cache_file, "rb") as f:
                spacy_doc, results = pickle.load(f)
            logger.info("Loaded cached spacy doc and results.")
            return spacy_doc, results

        spacy_doc = self.make_spacy_doc()
        logger.info("Spacy doc created.")

        n_cores = multiprocessing.cpu_count()

        def process_token(token):
            if token.is_stop or token.is_digit or token.is_punct or token.is_space:
                return None
            if token.like_url or token.like_num or token.like_email:
                return None
            if len(token.text) < 3 or token.text.isupper():
                return None
            return {
                "text": token.text,
                "lemma": token.lemma_,
                "pos": token.pos_,
                "pos_": token.pos,
                "word": token.lemma_,
                "sentiment": token.sentiment,
                "tag": token.tag_,
                "dep": token.dep_,
                "prob": token.prob,
                "idx": token.idx,
            }

        tokens = list(spacy_doc)
        _lemma = {}
        _pos = {}
        _pos_ = {}
        _word = {}
        _sentiment = {}
        _tag = {}
        _dep = {}
        _prob = {}
        _idx = {}
        with ThreadPoolExecutor() as executor:
            futures = {executor.submit(process_token, token): token for token in tokens}
            with tqdm(
                total=len(futures),
                desc=f"Processing tokens (parallel, {n_cores} cores)",
            ) as pbar:
                for future in as_completed(futures):
                    result = future.result()
                    if result is not None:
                        _lemma[result["text"]] = result["lemma"]
                        _pos[result["text"]] = result["pos"]
                        _pos_[result["text"]] = result["pos_"]
                        _word[result["text"]] = result["word"]
                        _sentiment[result["text"]] = result["sentiment"]
                        _tag = result["tag"]
                        _dep = result["dep"]
                        _prob = result["prob"]
                        _idx = result["idx"]
                    pbar.update(1)
        logger.info("Token processing complete.")
        results = {
            "lemma": _lemma,
            "pos": _pos,
            "pos_": _pos_,
            "word": _word,
            "sentiment": _sentiment,
            "tag": _tag,
            "dep": _dep,
            "prob": _prob,
            "idx": _idx,
        }
        # ! dump spacy_doc, results to a file for caching with the corpus id
        cache_dir = Path("cache")
        cache_dir.mkdir(exist_ok=True)
        cache_file = cache_dir / f"spacy_doc_{id}.pkl"
        with open(cache_file, "wb") as f:
            pickle.dump((spacy_doc, results), f)

        return spacy_doc, results

    def map_spacy_doc(self):
        spacy_doc, results = self.process_tokens(self._corpus.id if self._corpus else None)
        self._spacy_doc = spacy_doc
        self._lemma = results["lemma"]
        self._pos = results["pos"]
        self._pos_ = results["pos_"]
        self._word = results["word"]
        self._sentiment = results["sentiment"]
        self._tag = results["tag"]
        self._dep = results["dep"]
        self._prob = results["prob"]
        self._idx = results["idx"]

    def common_words(self, index=10):
        self.map_spacy_doc()
        _words = {}
        for key, value in self._word.items():
            _words[value] = _words.get(value, 0) + 1
        return sorted(_words.items(), key=operator.itemgetter(1), reverse=True)[:index]

    def common_nouns(self, index=10):
        self.map_spacy_doc()
        _words = {}
        for key, value in self._word.items():
            if self._pos.get(key, None) == "NOUN":
                _words[value] = _words.get(value, 0) + 1
        return sorted(_words.items(), key=operator.itemgetter(1), reverse=True)[:index]

    def common_verbs(self, index=10):
        self.map_spacy_doc()
        _words = {}
        for key, value in self._word.items():
            if self._pos.get(key, None) == "VERB":
                _words[value] = _words.get(value, 0) + 1
        return sorted(_words.items(), key=operator.itemgetter(1), reverse=True)[:index]

    def print_coding_dictionary(self, num=10, top_n=5):
        """Prints a coding dictionary based on common verbs, attributes, and dimensions.
        "CATEGORY" is the common verb
        "PROPERTY" is the common nouns associated with the verb
        "DIMENSION" is the common adjectives/adverbs/verbs associated with the property
        Args:
            num (int, optional): Number of common verbs to consider. Defaults to 10.
            top_n (int, optional): Number of top attributes and dimensions to consider for each verb. Defaults to 5.

        """
        self.map_spacy_doc()
        output = []
        coding_dict = []
        output.append(("CATEGORY", "PROPERTY", "DIMENSION"))
        verbs = self.common_verbs(num)
        _verbs = []
        for verb, freq in verbs:
            _verbs.append(verb)
        for verb, freq in verbs:
            for attribute, f2 in self.attributes(verb, top_n):
                for dimension, f3 in self.dimensions(attribute, top_n):
                    if dimension not in _verbs:
                        output.append((verb, attribute, dimension))
                        coding_dict.append(f"{verb} > {attribute} > {dimension}")
        # Add coding_dict to corpus metadata
        if self._corpus is not None:
            self._corpus.metadata["coding_dict"] = coding_dict
        print("\n---Coding Dictionary---")
        QRUtils.print_table(output)
        print("---------------------------\n")
        return output

    def sentences_with_common_nouns(self, index=10):
        self.map_spacy_doc()
        _nouns = self.common_nouns(index)
        # Let's look at the sentences
        sents = []
        # Ensure self._spacy_doc is initialized
        if self._spacy_doc is None:
            self._spacy_doc = self.make_spacy_doc()
        # the "sents" property returns spans
        # spans have indices into the original string
        # where each index value represents a token
        for span in self._spacy_doc.sents:
            # go from the start to the end of each span, returning each token in the sentence
            # combine each token using join()
            sent = " ".join(
                self._spacy_doc[i].text for i in range(span.start, span.end)
            ).strip()
            for noun, freq in _nouns:
                if noun in sent:
                    sents.append(sent)
        return sents

    def spans_with_common_nouns(self, word):
        self.map_spacy_doc()
        # Let's look at the sentences
        spans = []
        # the "sents" property returns spans
        # spans have indices into the original string
        # where each index value represents a token
        if self._spacy_doc is None:
            self._spacy_doc = self.make_spacy_doc()
        for span in self._spacy_doc.sents:
            # go from the start to the end of each span, returning each token in the sentence
            # combine each token using join()
            for token in span.text.split():
                if word in self._word.get(token, " "):
                    spans.append(span)
        return spans

    def dimensions(self, word, index=3):
        self.map_spacy_doc()
        _spans = self.spans_with_common_nouns(word)
        _ad = {}
        for span in _spans:
            for token in span.text.split():
                if self._pos.get(token, None) == "ADJ":
                    _ad[self._word.get(token)] = _ad.get(self._word.get(token), 0) + 1
                if self._pos.get(token, None) == "ADV":
                    _ad[self._word.get(token)] = _ad.get(self._word.get(token), 0) + 1
                if self._pos.get(token, None) == "VERB":
                    _ad[self._word.get(token)] = _ad.get(self._word.get(token), 0) + 1
        return sorted(_ad.items(), key=operator.itemgetter(1), reverse=True)[:index]

    def attributes(self, word, index=3):
        self.map_spacy_doc()
        _spans = self.spans_with_common_nouns(word)
        _ad = {}
        for span in _spans:
            for token in span.text.split():
                if self._pos.get(token, None) == "NOUN" and word not in self._word.get(
                    token, ""
                ):
                    _ad[self._word.get(token)] = _ad.get(self._word.get(token), 0) + 1
                    # if self._pos.get(token, None) == 'VERB':
                    # _ad[self._word.get(token)] = _ad.get(self._word.get(token), 0) + 1
        return sorted(_ad.items(), key=operator.itemgetter(1), reverse=True)[:index]

    # filter documents in the corpus based on metadata
    def filter_documents(self, metadata_key, metadata_value, mcp=False, id_column="id"):
        """
        Filter documents in the corpus based on metadata.
        If id_column exists in self._corpus.df, filter the DataFrame to match filtered documents' ids.
        """
        # * filter does not require spacy mapping
        # self.map_spacy_doc()
        if self._corpus is None:
            raise ValueError("Corpus is not set")
        filtered_documents = []
        for document in tqdm(
            self._corpus.documents,
            desc="Filtering documents",
            disable=len(self._corpus.documents) < 10,
        ):
            meta_val = document.metadata.get(metadata_key)
            # Check meta_val is not None and is iterable (str, list, tuple, set)
            if meta_val is not None and isinstance(meta_val, (str, list, tuple, set)):
                if metadata_value in meta_val:
                    filtered_documents.append(document)
            # Check document.id and document.text are not None and are str
            if isinstance(document.id, str) and metadata_value in document.id:
                filtered_documents.append(document)
            if isinstance(document.name, str) and metadata_value in document.name:
                filtered_documents.append(document)
        self._corpus.documents = filtered_documents

        # Check for id_column in self._corpus.df and filter df if present
        if (
            hasattr(self._corpus, "df")
            and self._corpus.df is not None
            and id_column in self._corpus.df.columns
        ):
            logger.info(f"id_column '{id_column}' exists in DataFrame.")
            filtered_ids = [doc.id for doc in filtered_documents]
            # Convert id_column to string before comparison
            self._corpus.df = self._corpus.df[
                self._corpus.df[id_column]
                .astype(str)
                .isin([str(i) for i in filtered_ids])
            ]
        else:
            logger.warning(f"id_column '{id_column}' does not exist in DataFrame.")

        if mcp:
            return f"Filtered {len(filtered_documents)} documents with {metadata_key} containing {metadata_value}"
        return filtered_documents

    # get the count of documents in the corpus
    def document_count(self):
        """
        Get the count of documents in the corpus.
        """
        if self._corpus is None:
            raise ValueError("Corpus is not set")
        return len(self._corpus.documents)

    def generate_summary(self, weight=10):
        """[summary]

        Args:
            weight (int, optional): Number of top common words whose sentences are included in the summary. Defaults to 10.

        Returns:
            list: A list of summary lines
        """
        self.map_spacy_doc()
        words = self.common_words()
        spans = []
        ct = 0
        for key, value in words:
            ct += 1
            if ct > weight:
                continue
            for span in self.spans_with_common_nouns(key):
                spans.append(span.text)
        if self._corpus is not None:
            self._corpus.metadata["summary"] = list(
                dict.fromkeys(spans)
            )  # remove duplicates
        return list(dict.fromkeys(spans))  # remove duplicates

    def print_categories(self, spacy_doc=None, num=10):
        self.map_spacy_doc()
        bot = self._spacy_doc._.to_bag_of_terms( # type: ignore
            by="lemma_",
            weighting="freq",
            ngs=(1, 2, 3),
            ents=True,
            ncs=True,
            dedupe=True,
        )
        categories = sorted(bot.items(), key=lambda x: x[1], reverse=True)[:num]
        output = []
        to_return = []
        print("\n---Categories with count---")
        output.append(("CATEGORY", "WEIGHT"))
        for category, count in categories:
            output.append((category, str(count)))
            to_return.append(category)
        QRUtils.print_table(output)
        print("---------------------------\n")
        if self._corpus is not None:
            self._corpus.metadata["categories"] = output
        return to_return

    def category_basket(self, num=10):
        item_basket = []
        spacy_docs, ids = self.make_each_document_into_spacy_doc()
        for spacy_doc in spacy_docs:
            item_basket.append(self.print_categories(spacy_doc, num))
        documents_copy = []
        documents = self._corpus.documents if self._corpus is not None else []
        # add categories to respective documents
        for i, document in enumerate(documents):
            if i < len(item_basket):
                document.metadata["categories"] = item_basket[i]
                documents_copy.append(document)
        # update the corpus with the new documents
        if self._corpus is not None:
            self._corpus.documents = documents_copy
        return item_basket
        # Example return:
        # [['GT', 'Strauss', 'coding', 'ground', 'theory', 'seminal', 'Corbin', 'code',
        # 'structure', 'ground theory'], ['category', 'theory', 'comparison', 'incident',
        # 'GT', 'structure', 'coding', 'Classical', 'Grounded', 'Theory'],
        # ['theory', 'GT', 'evaluation'], ['open', 'coding', 'category', 'QRMine',
        # 'open coding', 'researcher', 'step', 'data', 'break', 'analytically'],
        # ['ground', 'theory', 'GT', 'ground theory'], ['category', 'comparison', 'incident',
        # 'category comparison', 'Theory', 'theory']]

    def category_association(self, num=10):
        """Generates the support for itemsets

        Args:
            num (int, optional): Number of categories to generate for each doc in the corpus. Defaults to 10.
        """
        self.map_spacy_doc()
        basket = self.category_basket(num)
        te = TransactionEncoder()
        te_ary = te.fit(basket).transform(basket)
        df = pd.DataFrame(te_ary, columns=te.columns_)  # type: ignore
        _apriori = apriori(df, min_support=0.6, use_colnames=True)
        # Example
        #    support      itemsets
        # 0  0.666667          (GT)
        # 1  0.833333      (theory)
        # 2  0.666667  (theory, GT)
        documents_copy = []
        documents = self._corpus.documents if self._corpus is not None else []
        # TODO (Change) Add association rules to each document
        for i, document in enumerate(documents):
            if i < len(basket):
                # ! fix document.metadata["association_rules"] = _apriori #TODO This is a corpus metadata, not a document one
                documents_copy.append(document)
        # Add to corpus metadata
        if self._corpus is not None:
            self._corpus.metadata["association_rules"] = _apriori
        # Update the corpus with the new documents
        if self._corpus is not None:
            self._corpus.documents = documents_copy
        return _apriori

corpus property writable

Get the corpus.

initial_document_count property

Get the initial document count.

lang property writable

Get the language of the corpus.

max_length property writable

Get the maximum length of the corpus.

category_association(num=10)

Generates the support for itemsets

Parameters:

    Name    Type    Description                                                     Default
    num     int     Number of categories to generate for each doc in the corpus.   10
Source code in src/crisp_t/text.py, lines 597-627
def category_association(self, num=10):
    """Generates the support for itemsets

    Args:
        num (int, optional): Number of categories to generate for each doc in the corpus. Defaults to 10.
    """
    self.map_spacy_doc()
    basket = self.category_basket(num)
    te = TransactionEncoder()
    te_ary = te.fit(basket).transform(basket)
    df = pd.DataFrame(te_ary, columns=te.columns_)  # type: ignore
    _apriori = apriori(df, min_support=0.6, use_colnames=True)
    # Example
    #    support      itemsets
    # 0  0.666667          (GT)
    # 1  0.833333      (theory)
    # 2  0.666667  (theory, GT)
    documents_copy = []
    documents = self._corpus.documents if self._corpus is not None else []
    # TODO (Change) Add association rules to each document
    for i, document in enumerate(documents):
        if i < len(basket):
            # ! fix document.metadata["association_rules"] = _apriori #TODO This is a corpus metadata, not a document one
            documents_copy.append(document)
    # Add to corpus metadata
    if self._corpus is not None:
        self._corpus.metadata["association_rules"] = _apriori
    # Update the corpus with the new documents
    if self._corpus is not None:
        self._corpus.documents = documents_copy
    return _apriori
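
To make the support output concrete, here is a standalone mlxtend sketch with a made-up category basket; category_association builds the real basket from the corpus via category_basket.

import pandas as pd
from mlxtend.frequent_patterns import apriori
from mlxtend.preprocessing import TransactionEncoder

basket = [
    ["theory", "GT", "coding"],
    ["theory", "GT"],
    ["theory", "evaluation"],
]

te = TransactionEncoder()
te_ary = te.fit(basket).transform(basket)
df = pd.DataFrame(te_ary, columns=te.columns_)

# Itemsets appearing in at least 60% of documents, e.g. (theory) and (GT, theory).
print(apriori(df, min_support=0.6, use_colnames=True))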

document_count()

Get the count of documents in the corpus.

Source code in src/crisp_t/text.py, lines 515-521
def document_count(self):
    """
    Get the count of documents in the corpus.
    """
    if self._corpus is None:
        raise ValueError("Corpus is not set")
    return len(self._corpus.documents)

filter_documents(metadata_key, metadata_value, mcp=False, id_column='id')

Filter documents in the corpus based on metadata. If id_column exists in self._corpus.df, filter the DataFrame to match filtered documents' ids.

Source code in src/crisp_t/text.py, lines 466-512
def filter_documents(self, metadata_key, metadata_value, mcp=False, id_column="id"):
    """
    Filter documents in the corpus based on metadata.
    If id_column exists in self._corpus.df, filter the DataFrame to match filtered documents' ids.
    """
    # * filter does not require spacy mapping
    # self.map_spacy_doc()
    if self._corpus is None:
        raise ValueError("Corpus is not set")
    filtered_documents = []
    for document in tqdm(
        self._corpus.documents,
        desc="Filtering documents",
        disable=len(self._corpus.documents) < 10,
    ):
        meta_val = document.metadata.get(metadata_key)
        # Check meta_val is not None and is iterable (str, list, tuple, set)
        if meta_val is not None and isinstance(meta_val, (str, list, tuple, set)):
            if metadata_value in meta_val:
                filtered_documents.append(document)
        # Check document.id and document.text are not None and are str
        if isinstance(document.id, str) and metadata_value in document.id:
            filtered_documents.append(document)
        if isinstance(document.name, str) and metadata_value in document.name:
            filtered_documents.append(document)
    self._corpus.documents = filtered_documents

    # Check for id_column in self._corpus.df and filter df if present
    if (
        hasattr(self._corpus, "df")
        and self._corpus.df is not None
        and id_column in self._corpus.df.columns
    ):
        logger.info(f"id_column '{id_column}' exists in DataFrame.")
        filtered_ids = [doc.id for doc in filtered_documents]
        # Convert id_column to string before comparison
        self._corpus.df = self._corpus.df[
            self._corpus.df[id_column]
            .astype(str)
            .isin([str(i) for i in filtered_ids])
        ]
    else:
        logger.warning(f"id_column '{id_column}' does not exist in DataFrame.")

    if mcp:
        return f"Filtered {len(filtered_documents)} documents with {metadata_key} containing {metadata_value}"
    return filtered_documents
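
The DataFrame side of the filter is plain pandas; a minimal sketch of that step with made-up ids (the document filtering itself needs a loaded Corpus):

import pandas as pd

df = pd.DataFrame({"id": [1, 2, 3], "score": [0.2, 0.9, 0.4]})
filtered_ids = ["2", "3"]  # ids of the documents that survived the metadata filter

# Compare as strings, as filter_documents does, so int and str ids still match.
df = df[df["id"].astype(str).isin([str(i) for i in filtered_ids])]
print(df)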

generate_summary(weight=10)

Generate a summary from sentences that contain the most common words.

Parameters:

    Name      Type    Description                                                                   Default
    weight    int     Number of top common words whose sentences are included in the summary.      10

Returns:

    Type    Description
    list    A list of summary lines

Source code in src/crisp_t/text.py, lines 523-546
def generate_summary(self, weight=10):
    """[summary]

    Args:
        weight (int, optional): Number of top common words whose sentences are included in the summary. Defaults to 10.

    Returns:
        list: A list of summary lines
    """
    self.map_spacy_doc()
    words = self.common_words()
    spans = []
    ct = 0
    for key, value in words:
        ct += 1
        if ct > weight:
            continue
        for span in self.spans_with_common_nouns(key):
            spans.append(span.text)
    if self._corpus is not None:
        self._corpus.metadata["summary"] = list(
            dict.fromkeys(spans)
        )  # remove duplicates
    return list(dict.fromkeys(spans))  # remove duplicates
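
The duplicate removal relies on dict.fromkeys preserving insertion order; a tiny sketch:

spans = ["a first sentence", "another sentence", "a first sentence"]
print(list(dict.fromkeys(spans)))  # ['a first sentence', 'another sentence']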

print_coding_dictionary(num=10, top_n=5)

Prints a coding dictionary based on common verbs, attributes, and dimensions.

    CATEGORY  - the common verb
    PROPERTY  - the common nouns associated with the verb
    DIMENSION - the common adjectives, adverbs, and verbs associated with the property

Parameters:

    Name     Type    Description                                                            Default
    num      int     Number of common verbs to consider.                                    10
    top_n    int     Number of top attributes and dimensions to consider for each verb.     5

Source code in src/crisp_t/text.py, lines 366-396
def print_coding_dictionary(self, num=10, top_n=5):
    """Prints a coding dictionary based on common verbs, attributes, and dimensions.
    "CATEGORY" is the common verb
    "PROPERTY" is the common nouns associated with the verb
    "DIMENSION" is the common adjectives/adverbs/verbs associated with the property
    Args:
        num (int, optional): Number of common verbs to consider. Defaults to 10.
        top_n (int, optional): Number of top attributes and dimensions to consider for each verb. Defaults to 5.

    """
    self.map_spacy_doc()
    output = []
    coding_dict = []
    output.append(("CATEGORY", "PROPERTY", "DIMENSION"))
    verbs = self.common_verbs(num)
    _verbs = []
    for verb, freq in verbs:
        _verbs.append(verb)
    for verb, freq in verbs:
        for attribute, f2 in self.attributes(verb, top_n):
            for dimension, f3 in self.dimensions(attribute, top_n):
                if dimension not in _verbs:
                    output.append((verb, attribute, dimension))
                    coding_dict.append(f"{verb} > {attribute} > {dimension}")
    # Add coding_dict to corpus metadata
    if self._corpus is not None:
        self._corpus.metadata["coding_dict"] = coding_dict
    print("\n---Coding Dictionary---")
    QRUtils.print_table(output)
    print("---------------------------\n")
    return output

process_text(text)

Process the text by removing unwanted characters and normalizing it.

Source code in src/crisp_t/text.py, lines 224-238
def process_text(self, text: str) -> str:
    """
    Process the text by removing unwanted characters and normalizing it.
    """
    # Remove unwanted characters
    text = preprocessing.replace.urls(text)
    text = preprocessing.replace.emails(text)
    text = preprocessing.replace.phone_numbers(text)
    text = preprocessing.replace.currency_symbols(text)
    text = preprocessing.replace.hashtags(text)
    text = preprocessing.replace.numbers(text)

    # lowercase the text
    text = text.lower()
    return text
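
The replacements come from textacy's preprocessing module; a standalone sketch on a made-up string, assuming textacy is installed (the placeholder tokens such as _URL_ are textacy's defaults):

from textacy import preprocessing

text = "Email me at jane@example.com or visit https://example.org for details."
text = preprocessing.replace.urls(text)
text = preprocessing.replace.emails(text)
text = preprocessing.replace.numbers(text)
print(text.lower())  # urls/emails/numbers replaced by textacy's placeholder tokens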

process_tokens(id='corpus')

Process tokens in the spacy document and extract relevant information.

Source code in src/crisp_t/text.py, lines 241-328
def process_tokens(self, id="corpus"):
    """
    Process tokens in the spacy document and extract relevant information.
    """

    # ! if cached file exists, load it
    cache_dir = Path("cache")
    cache_file = cache_dir / f"spacy_doc_{id}.pkl"
    if cache_file.exists():
        with open(cache_file, "rb") as f:
            spacy_doc, results = pickle.load(f)
        logger.info("Loaded cached spacy doc and results.")
        return spacy_doc, results

    spacy_doc = self.make_spacy_doc()
    logger.info("Spacy doc created.")

    n_cores = multiprocessing.cpu_count()

    def process_token(token):
        if token.is_stop or token.is_digit or token.is_punct or token.is_space:
            return None
        if token.like_url or token.like_num or token.like_email:
            return None
        if len(token.text) < 3 or token.text.isupper():
            return None
        return {
            "text": token.text,
            "lemma": token.lemma_,
            "pos": token.pos_,
            "pos_": token.pos,
            "word": token.lemma_,
            "sentiment": token.sentiment,
            "tag": token.tag_,
            "dep": token.dep_,
            "prob": token.prob,
            "idx": token.idx,
        }

    tokens = list(spacy_doc)
    _lemma = {}
    _pos = {}
    _pos_ = {}
    _word = {}
    _sentiment = {}
    _tag = {}
    _dep = {}
    _prob = {}
    _idx = {}
    with ThreadPoolExecutor() as executor:
        futures = {executor.submit(process_token, token): token for token in tokens}
        with tqdm(
            total=len(futures),
            desc=f"Processing tokens (parallel, {n_cores} cores)",
        ) as pbar:
            for future in as_completed(futures):
                result = future.result()
                if result is not None:
                    _lemma[result["text"]] = result["lemma"]
                    _pos[result["text"]] = result["pos"]
                    _pos_[result["text"]] = result["pos_"]
                    _word[result["text"]] = result["word"]
                    _sentiment[result["text"]] = result["sentiment"]
                    _tag = result["tag"]
                    _dep = result["dep"]
                    _prob = result["prob"]
                    _idx = result["idx"]
                pbar.update(1)
    logger.info("Token processing complete.")
    results = {
        "lemma": _lemma,
        "pos": _pos,
        "pos_": _pos_,
        "word": _word,
        "sentiment": _sentiment,
        "tag": _tag,
        "dep": _dep,
        "prob": _prob,
        "idx": _idx,
    }
    # ! dump spacy_doc, results to a file for caching with the corpus id
    cache_dir = Path("cache")
    cache_dir.mkdir(exist_ok=True)
    cache_file = cache_dir / f"spacy_doc_{id}.pkl"
    with open(cache_file, "wb") as f:
        pickle.dump((spacy_doc, results), f)

    return spacy_doc, results
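
The cache written above is a plain pickle under cache/ keyed by the corpus id (the documented default id is "corpus"); a hedged sketch of reading it back outside the class:

import pickle
from pathlib import Path

cache_file = Path("cache") / "spacy_doc_corpus.pkl"  # path convention from process_tokens
if cache_file.exists():
    with open(cache_file, "rb") as f:
        spacy_doc, results = pickle.load(f)
    # results holds per-token lookups, e.g. results["lemma"] and results["pos"].
    print(len(results["lemma"]), "lemmas cached")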

ML

Source code in src/crisp_t/ml.py, lines 119-1078
class ML:
    def __init__(
        self,
        csv: Csv,
    ):
        if not ML_INSTALLED:
            raise ImportError("ML dependencies are not installed.")
        self._csv = csv
        self._epochs = 3
        self._samplesize = 0

    @property
    def csv(self):
        return self._csv

    @property
    def corpus(self):
        return self._csv.corpus

    @csv.setter
    def csv(self, value):
        if isinstance(value, Csv):
            self._csv = value
        else:
            raise ValueError(f"The input belongs to {type(value)} instead of Csv.")

    def get_kmeans(self, number_of_clusters=3, seed=42, verbose=True, mcp=False):
        if self._csv is None:
            raise ValueError(
                "CSV data is not set. Please set self.csv before calling get_kmeans."
            )
        X, _ = self._csv.read_xy("")  # No output variable for clustering
        if X is None:
            raise ValueError(
                "Input features X are None. Cannot perform KMeans clustering."
            )
        kmeans = KMeans(
            n_clusters=number_of_clusters, init="k-means++", random_state=seed
        )
        self._clusters = kmeans.fit_predict(X)
        members = self._get_members(self._clusters, number_of_clusters)
        # Add cluster info to csv to metadata_cluster column
        if self._csv is not None and getattr(self._csv, "df", None) is not None:
            self._csv.df["metadata_cluster"] = self._clusters
        if verbose:
            print("KMeans Cluster Centers:\n", kmeans.cluster_centers_)
            print(
                "KMeans Inertia (Sum of squared distances to closest cluster center):\n",
                kmeans.inertia_,
            )
            if self._csv.corpus is not None:
                self._csv.corpus.metadata["kmeans"] = (
                    f"KMeans clustering with {number_of_clusters} clusters. Inertia: {kmeans.inertia_}"
                )
        # Add members info to corpus metadata
        members_info = "\n".join(
            [
                f"Cluster {i}: {len(members[i])} members"
                for i in range(number_of_clusters)
            ]
        )
        if self._csv.corpus is not None:
            self._csv.corpus.metadata["kmeans_members"] = (
                f"KMeans clustering members:\n{members_info}"
            )
        if mcp:
            return members_info
        return self._clusters, members

    def _get_members(self, clusters, number_of_clusters=3):
        _df = self._csv.df
        self._csv.df = _df
        members = []
        for i in range(number_of_clusters):
            members.append([])
        for i, cluster in enumerate(clusters):
            members[cluster].append(i)
        return members

    def profile(self, members, number_of_clusters=3):
        if self._csv is None:
            raise ValueError(
                "CSV data is not set. Please set self.csv before calling profile."
            )
        _corpus = self._csv.corpus
        _numeric_clusters = ""
        for i in range(number_of_clusters):
            print("Cluster: ", i)
            print("Cluster Length: ", len(members[i]))
            print("Cluster Members")
            if self._csv is not None and getattr(self._csv, "df", None) is not None:
                print(self._csv.df.iloc[members[i], :])
                print("Centroids")
                print(self._csv.df.iloc[members[i], :].mean(axis=0))
                _numeric_clusters += f"Cluster {i} with {len(members[i])} members\n has the following centroids (mean values):\n"
                _numeric_clusters += (
                    f"{self._csv.df.iloc[members[i], :].mean(axis=0)}\n"
                )
            else:
                print("DataFrame (self._csv.df) is not set.")
        if _corpus is not None:
            _corpus.metadata["numeric_clusters"] = _numeric_clusters
            self._csv.corpus = _corpus
        return members

    def get_nnet_predictions(self, y: str, mcp=False):
        """
        Extended: Handles binary (BCELoss) and multi-class (CrossEntropyLoss).
        Returns list of predicted original class labels.
        """
        if ML_INSTALLED is False:
            logger.info(
                "ML dependencies are not installed. Please install them by ```pip install crisp-t[ml] to use ML features."
            )
            return None

        if self._csv is None:
            raise ValueError(
                "CSV data is not set. Please set self.csv before calling profile."
            )
        _corpus = self._csv.corpus

        X_np, Y_raw, X, Y = self._process_xy(y=y)

        unique_classes = np.unique(Y_raw)
        num_classes = unique_classes.size
        if num_classes < 2:
            raise ValueError(f"Need at least 2 classes; found {num_classes}.")

        vnum = X_np.shape[1]

        # Binary path
        if num_classes == 2:
            # Map to {0.0,1.0} for BCELoss if needed
            mapping_applied = False
            class_mapping = {}
            inverse_mapping = {}
            # Ensure deterministic order
            sorted_classes = sorted(unique_classes.tolist())
            if not (sorted_classes == [0, 1] or sorted_classes == [0.0, 1.0]):
                class_mapping = {sorted_classes[0]: 0.0, sorted_classes[1]: 1.0}
                inverse_mapping = {v: k for k, v in class_mapping.items()}
                Y_mapped = np.vectorize(class_mapping.get)(Y_raw).astype(np.float32)
                mapping_applied = True
            else:
                Y_mapped = Y_raw.astype(np.float32)

            model = NeuralNet(vnum)
            try:
                criterion = nn.BCELoss()  # type: ignore
                optimizer = optim.Adam(model.parameters(), lr=0.001)  # type: ignore

                X_tensor = torch.from_numpy(X_np)  # type: ignore
                y_tensor = torch.from_numpy(Y_mapped.astype(np.float32)).view(-1, 1)  # type: ignore

                dataset = TensorDataset(X_tensor, y_tensor)  # type: ignore
                dataloader = DataLoader(dataset, batch_size=32, shuffle=True)  # type: ignore
            except Exception as e:
                logger.error(f"Error occurred while creating DataLoader: {e}")
                return None

            for _ in range(self._epochs):
                for batch_X, batch_y in dataloader:
                    optimizer.zero_grad()
                    outputs = model(batch_X)
                    loss = criterion(outputs, batch_y)
                    if torch.isnan(loss):  # type: ignore
                        raise RuntimeError("NaN loss encountered.")
                    loss.backward()
                    optimizer.step()

            # Predictions
            bin_preds_internal = None
            if torch:
                with torch.no_grad():
                    probs = model(torch.from_numpy(X_np)).view(-1).cpu().numpy()
                bin_preds_internal = (probs >= 0.5).astype(int)

            if mapping_applied:
                preds = [inverse_mapping[float(p)] for p in bin_preds_internal]  # type: ignore
                y_eval = np.vectorize(class_mapping.get)(Y_raw).astype(int)
                preds_eval = bin_preds_internal
            else:
                preds = bin_preds_internal.tolist()  # type: ignore
                y_eval = Y_mapped.astype(int)
                preds_eval = bin_preds_internal

            accuracy = (preds_eval == y_eval).sum() / len(y_eval)
            print(
                f"\nPredicting {y} with {X.shape[1]} features for {self._epochs} epochs gave an accuracy (convergence): {accuracy*100:.2f}%\n"
            )
            if _corpus is not None:
                _corpus.metadata["nnet_predictions"] = (
                    f"Predicting {y} with {X.shape[1]} features for {self._epochs} epochs gave an accuracy (convergence): {accuracy*100:.2f}%"
                )
            if mcp:
                return f"Predicting {y} with {X.shape[1]} features for {self._epochs} epochs gave an accuracy (convergence): {accuracy*100:.2f}%"
            return preds

        # Multi-class path
        # Map original classes to indices
        sorted_classes = sorted(unique_classes.tolist())
        class_to_idx = {c: i for i, c in enumerate(sorted_classes)}
        idx_to_class = {i: c for c, i in class_to_idx.items()}
        Y_idx = np.vectorize(class_to_idx.get)(Y_raw).astype(np.int64)

        model = MultiClassNet(vnum, num_classes)
        criterion = nn.CrossEntropyLoss()  # type: ignore
        optimizer = optim.Adam(model.parameters(), lr=0.001)  # type: ignore

        X_tensor = torch.from_numpy(X_np)  # type: ignore
        y_tensor = torch.from_numpy(Y_idx)  # type: ignore

        dataset = TensorDataset(X_tensor, y_tensor)  # type: ignore
        dataloader = DataLoader(dataset, batch_size=32, shuffle=True)  # type: ignore

        for _ in range(self._epochs):
            for batch_X, batch_y in dataloader:
                optimizer.zero_grad()
                logits = model(batch_X)
                loss = criterion(logits, batch_y)
                if torch.isnan(loss):  # type: ignore
                    raise RuntimeError("NaN loss encountered.")
                loss.backward()
                optimizer.step()

        with torch.no_grad():  # type: ignore
            logits_full = model(torch.from_numpy(X_np))  # type: ignore
            pred_indices = torch.argmax(logits_full, dim=1).cpu().numpy()  # type: ignore

        preds = [idx_to_class[i] for i in pred_indices]
        accuracy = (pred_indices == Y_idx).sum() / len(Y_idx)
        print(
            f"\nPredicting {y} with {X.shape[1]} features for {self._epochs} gave an accuracy (convergence): {accuracy*100:.2f}%\n"
        )
        if _corpus is not None:
            _corpus.metadata["nnet_predictions"] = (
                f"Predicting {y} with {X.shape[1]} features for {self._epochs} gave an accuracy (convergence): {accuracy*100:.2f}%"
            )
        if mcp:
            return f"Predicting {y} with {X.shape[1]} features for {self._epochs} gave an accuracy (convergence): {accuracy*100:.2f}%"
        return preds

    def _convert_to_binary(self, Y):
        unique_values = np.unique(Y)
        if len(unique_values) != 2:
            logger.warning(
                "Target variable has more than two unique values."
            )
            # convert unique_values[0] to 0, rest to 1
            mapping = {val: (0 if val == unique_values[0] else 1) for val in unique_values}
        else:
            mapping = {unique_values[0]: 0, unique_values[1]: 1}
        Y_binary = np.vectorize(mapping.get)(Y)
        print(f"Converted target variable to binary using mapping: {mapping}")
        return Y_binary

    def svm_confusion_matrix(self, y: str, test_size=0.25, random_state=0, mcp=False):
        """Generate confusion matrix for SVM

        Returns:
            np.ndarray: The confusion matrix for the test split.
        """
        X_np, Y_raw, X, Y = self._process_xy(y=y)
        Y = self._convert_to_binary(Y)
        X_train, X_test, y_train, y_test = train_test_split(
            X, Y, test_size=test_size, random_state=random_state
        )
        sc = StandardScaler()
        # Issue #22
        y_test = y_test.astype("int")
        y_train = y_train.astype("int")
        X_train = sc.fit_transform(X_train)
        X_test = sc.transform(X_test)
        classifier = SVC(kernel="linear", random_state=0)
        classifier.fit(X_train, y_train)
        y_pred = classifier.predict(X_test)
        # Issue #22
        y_pred = y_pred.astype("int")
        _confusion_matrix = confusion_matrix(y_test, y_pred)
        print(f"Confusion Matrix for SVM predicting {y}:\n{_confusion_matrix}")
        # Output
        # [[2 0]
        #  [2 0]]
        if self._csv.corpus is not None:
            self._csv.corpus.metadata["svm_confusion_matrix"] = (
                f"Confusion Matrix for SVM predicting {y}:\n{self.format_confusion_matrix_to_human_readable(_confusion_matrix)}"
            )

        if mcp:
            return f"Confusion Matrix for SVM predicting {y}:\n{self.format_confusion_matrix_to_human_readable(_confusion_matrix)}"

        return _confusion_matrix

    def format_confusion_matrix_to_human_readable(
        self, confusion_matrix: np.ndarray
    ) -> str:
        """Format the confusion matrix to a human-readable string.

        Args:
            confusion_matrix (np.ndarray): The confusion matrix to format.

        Returns:
            str: The formatted confusion matrix with true positive, false positive, true negative, and false negative counts.
        """
        tn, fp, fn, tp = confusion_matrix.ravel()
        return (
            f"True Positive: {tp}\n"
            f"False Positive: {fp}\n"
            f"True Negative: {tn}\n"
            f"False Negative: {fn}\n"
        )

    # https://stackoverflow.com/questions/45419203/python-numpy-extracting-a-row-from-an-array
    def knn_search(self, y: str, n=3, r=3, mcp=False):
        X_np, Y_raw, X, Y = self._process_xy(y=y)
        kdt = KDTree(X_np, leaf_size=2, metric="euclidean")
        dist, ind = kdt.query(X_np[r - 1 : r, :], k=n)
        # Display results as human readable (1-based)
        ind = (ind + 1).tolist()  # Convert to 1-based index
        dist = dist.tolist()
        print(
            f"\nKNN search for {y} (n={n}, record no: {r}): {ind} with distances {dist}\n"
        )
        if self._csv.corpus is not None:
            self._csv.corpus.metadata["knn_search"] = (
                f"KNN search for {y} (n={n}, record no: {r}): {ind} with distances {dist}"
            )
        if mcp:
            return f"KNN search for {y} (n={n}, record no: {r}): {ind} with distances {dist}"
        return dist, ind

    def _process_xy(self, y: str, oversample=False, one_hot_encode_all=False):
        X, Y = self._csv.prepare_data(
            y=y, oversample=oversample, one_hot_encode_all=one_hot_encode_all
        )
        if X is None or Y is None:
            raise ValueError("prepare_data returned None for X or Y.")

        # To numpy float32
        X_np = (
            X.to_numpy(dtype=np.float32)
            if hasattr(X, "to_numpy")
            else np.asarray(X, dtype=np.float32)
        )
        Y_raw = Y.to_numpy() if hasattr(Y, "to_numpy") else np.asarray(Y)

        # Handle NaNs
        if np.isnan(X_np).any():
            raise ValueError("NaN detected in feature matrix.")
        if np.isnan(Y_raw.astype(float, copy=False)).any():
            raise ValueError("NaN detected in target vector.")

        return X_np, Y_raw, X, Y

    def get_decision_tree_classes(
        self, y: str, top_n=5, test_size=0.5, random_state=1, mcp=False
    ):
        X_np, Y_raw, X, Y = self._process_xy(y=y)
        Y_raw = self._convert_to_binary(Y_raw)
        X_train, X_test, y_train, y_test = train_test_split(
            X_np, Y_raw, test_size=test_size, random_state=random_state
        )

        # print(f"X_train shape: {X_train.shape}, y_train shape: {y_train.shape}")
        # print(f"X_test shape: {X_test.shape}, y_test shape: {y_test.shape}")

        # Train a RandomForestClassifier
        clf = RandomForestClassifier(n_estimators=100, random_state=42)
        clf.fit(X_train, y_train)

        # Compute permutation importance
        results = permutation_importance(
            clf, X_test, y_test, n_repeats=10, random_state=42
        )

        # classifier = DecisionTreeClassifier(random_state=random_state) # type: ignore
        # classifier.fit(X_train, y_train)
        y_pred = clf.predict(X_test)
        _confusion_matrix = confusion_matrix(y_test, y_pred)
        print(
            f"Confusion Matrix for Decision Tree predicting {y}:\n{_confusion_matrix}"
        )
        # Output
        # [[2 0]
        #  [2 0]]

        accuracy = accuracy_score(y_test, y_pred)
        print(f"\nAccuracy: {accuracy}\n")

        # Retrieve feature importance scores
        importance = results.importances_mean

        # Get indices of top N important features
        top_n_indices = np.argsort(importance)[-top_n:][::-1]

        # Display feature importance
        print(f"==== Top {top_n} important features ====\n")
        _importance = ""
        for i, v in enumerate(top_n_indices):
            print(f"Feature: {X.columns[v]}, Score: {importance[v]:.5f}")
            _importance += f"Feature: {X.columns[v]}, Score: {importance[v]:.5f}\n"

        if self._csv.corpus is not None:
            self._csv.corpus.metadata["decision_tree_accuracy"] = (
                f"Decision Tree accuracy for predicting {y}: {accuracy*100:.2f}%"
            )
            self._csv.corpus.metadata["decision_tree_confusion_matrix"] = (
                f"Confusion Matrix for Decision Tree predicting {y}:\n{self.format_confusion_matrix_to_human_readable(_confusion_matrix)}"
            )
            self._csv.corpus.metadata["decision_tree_feature_importance"] = _importance
        if mcp:
            return f"""
            Confusion Matrix for Decision Tree predicting {y}:\n{self.format_confusion_matrix_to_human_readable(_confusion_matrix)}\nTop {top_n} important features:\n{_importance}
            Accuracy: {accuracy*100:.2f}%
            """
        return _confusion_matrix, importance

    def get_xgb_classes(
        self, y: str, oversample=False, test_size=0.25, random_state=0, mcp=False
    ):
        try:
            from xgboost import XGBClassifier  # type: ignore
        except ImportError:
            raise ImportError(
                "XGBoost is not installed. Please install it via `pip install crisp-t[xg]`."
            )
        X_np, Y_raw, X, Y = self._process_xy(y=y)
        if ML_INSTALLED:
            # ValueError: Invalid classes inferred from unique values of `y`.  Expected: [0 1], got [1 2]
            # convert y to binary
            Y_binary = (Y_raw == 1).astype(int)
            X_train, X_test, y_train, y_test = train_test_split(
                X_np, Y_binary, test_size=test_size, random_state=random_state
            )
            classifier = XGBClassifier(use_label_encoder=False, eval_metric="logloss")  # type: ignore
            classifier.fit(X_train, y_train)
            y_pred = classifier.predict(X_test)
            _confusion_matrix = confusion_matrix(y_test, y_pred)
            print(f"Confusion Matrix for XGBoost predicting {y}:\n{_confusion_matrix}")
            # Output
            # [[2 0]
            #  [2 0]]
            if self._csv.corpus is not None:
                self._csv.corpus.metadata["xgb_confusion_matrix"] = (
                    f"Confusion Matrix for XGBoost predicting {y}:\n{_confusion_matrix}"
                )
            if mcp:
                return f"""
                Confusion Matrix for XGBoost predicting {y}:\n{self.format_confusion_matrix_to_human_readable(_confusion_matrix)}
                """
            return _confusion_matrix
        else:
            raise ImportError("ML dependencies are not installed.")

    def get_apriori(
        self, y: str, min_support=0.9, use_colnames=True, min_threshold=0.5, mcp=False
    ):
        if ML_INSTALLED:
            X_np, Y_raw, X, Y = self._process_xy(y=y, one_hot_encode_all=True)
            frequent_itemsets = apriori(X, min_support=min_support, use_colnames=use_colnames)  # type: ignore
            # rules = association_rules(frequent_itemsets, metric="lift", min_threshold=min_threshold) # type: ignore
            human_readable = tabulate(
                frequent_itemsets.head(10), headers="keys", tablefmt="pretty"  # type: ignore
            )
            if self._csv.corpus is not None:
                self._csv.corpus.metadata["apriori_frequent_itemsets"] = human_readable
            if mcp:
                return f"Frequent itemsets (top 10):\n{human_readable}"
            return frequent_itemsets  # , rules
        else:
            raise ImportError("ML dependencies are not installed.")

    def get_pca(self, y: str, n: int = 3, mcp=False):
        """
        Perform a manual PCA (no sklearn PCA) on the feature matrix for target y.

        Args:
            y (str): Target column name (used only for data preparation).
            n (int): Number of principal components to keep.

        Returns:
            dict: {
                'covariance_matrix': cov_mat,
                'eigenvalues': eig_vals_sorted,
                'eigenvectors': eig_vecs_sorted,
                'explained_variance_ratio': var_exp,
                'cumulative_explained_variance_ratio': cum_var_exp,
                'projection_matrix': matrix_w,
                'transformed': X_pca
            }
        """
        X_np, Y_raw, X, Y = self._process_xy(y=y)
        X_std = StandardScaler().fit_transform(X_np)

        cov_mat = np.cov(X_std.T)
        eig_vals, eig_vecs = np.linalg.eigh(cov_mat)  # symmetric matrix -> eigh

        # Sort eigenvalues (and vectors) descending
        idx = np.argsort(eig_vals)[::-1]
        eig_vals_sorted = eig_vals[idx]
        eig_vecs_sorted = eig_vecs[:, idx]

        factors = X_std.shape[1]
        n = max(1, min(n, factors))

        # Explained variance ratios
        tot = eig_vals_sorted.sum()
        var_exp = (eig_vals_sorted / tot) * 100.0
        cum_var_exp = np.cumsum(var_exp)

        # Projection matrix (first n eigenvectors)
        matrix_w = eig_vecs_sorted[:, :n]

        # Project data
        X_pca = X_std @ matrix_w

        # Optional prints (retain original behavior)
        print("Covariance matrix:\n", cov_mat)
        print("Eigenvalues (desc):\n", eig_vals_sorted)
        print("Explained variance (%):\n", var_exp[:n])
        print("Cumulative explained variance (%):\n", cum_var_exp[:n])
        print("Projection matrix (W):\n", matrix_w)
        print("Transformed (first 5 rows):\n", X_pca[:5])

        result = {
            "covariance_matrix": cov_mat,
            "eigenvalues": eig_vals_sorted,
            "eigenvectors": eig_vecs_sorted,
            "explained_variance_ratio": var_exp,
            "cumulative_explained_variance_ratio": cum_var_exp,
            "projection_matrix": matrix_w,
            "transformed": X_pca,
        }

        if self._csv.corpus is not None:
            self._csv.corpus.metadata["pca"] = (
                f"PCA kept {n} components explaining "
                f"{cum_var_exp[n-1]:.2f}% variance."
            )
        if mcp:
            return (
                f"PCA kept {n} components explaining {cum_var_exp[n-1]:.2f}% variance."
            )
        return result

    def get_regression(self, y: str, mcp=False):
        """
        Perform linear or logistic regression based on the outcome variable type.

        If the outcome is binary, fit a logistic regression model.
        Otherwise, fit a linear regression model.

        Args:
            y (str): Target column name for the regression.

        Returns:
            dict: Regression results including coefficients, intercept, and metrics.
        """
        if ML_INSTALLED is False:
            logger.info(
                "ML dependencies are not installed. Please install them by ```pip install crisp-t[ml] to use ML features."
            )
            return None

        if self._csv is None:
            raise ValueError(
                "CSV data is not set. Please set self.csv before calling get_regression."
            )

        X_np, Y_raw, X, Y = self._process_xy(y=y)

        # Check if outcome is binary (logistic) or continuous (linear)
        unique_values = np.unique(Y_raw)
        num_unique = len(unique_values)

        # Determine if binary classification or regression
        is_binary = num_unique == 2

        if is_binary:
            # Logistic Regression
            print(f"\n=== Logistic Regression for {y} ===")
            print(f"Binary outcome detected with values: {unique_values}")

            model = LogisticRegression(max_iter=1000, random_state=42)
            model.fit(X_np, Y_raw)

            # Predictions
            y_pred = model.predict(X_np)

            # Accuracy
            accuracy = accuracy_score(Y_raw, y_pred)
            print(f"\nAccuracy: {accuracy*100:.2f}%")

            # Coefficients and Intercept
            print(f"\nCoefficients:")
            for i, coef in enumerate(model.coef_[0]):
                feature_name = X.columns[i] if hasattr(X, "columns") else f"Feature_{i}"
                print(f"  {feature_name}: {coef:.5f}")

            print(f"\nIntercept: {model.intercept_[0]:.5f}")

            coef_str = "\n".join(
                [
                    f"  {X.columns[i] if hasattr(X, 'columns') else f'Feature_{i}'}: {coef:.5f}"
                    for i, coef in enumerate(model.coef_[0])
                ]
            )

            # Store in metadata
            if self._csv.corpus is not None:
                self._csv.corpus.metadata["logistic_regression_accuracy"] = (
                    f"Logistic Regression accuracy for predicting {y}: {accuracy*100:.2f}%"
                )
                self._csv.corpus.metadata["logistic_regression_coefficients"] = (
                    f"Coefficients:\n{coef_str}"
                )
                self._csv.corpus.metadata["logistic_regression_intercept"] = (
                    f"Intercept: {model.intercept_[0]:.5f}"
                )

            if mcp:
                return f"""
                Logistic Regression accuracy for predicting {y}: {accuracy*100:.2f}%
                Coefficients:
                {coef_str}
                Intercept: {model.intercept_[0]:.5f}
                """
            return {
                "model_type": "logistic",
                "accuracy": accuracy,
                "coefficients": model.coef_[0],
                "intercept": model.intercept_[0],
                "feature_names": X.columns.tolist() if hasattr(X, "columns") else None,
            }
        else:
            # Linear Regression
            print(f"\n=== Linear Regression for {y} ===")
            print(f"Continuous outcome detected with {num_unique} unique values")

            model = LinearRegression()
            model.fit(X_np, Y_raw)

            # Predictions
            y_pred = model.predict(X_np)

            # Metrics
            mse = mean_squared_error(Y_raw, y_pred)
            r2 = r2_score(Y_raw, y_pred)
            print(f"\nMean Squared Error (MSE): {mse:.5f}")
            print(f"R² Score: {r2:.5f}")

            # Coefficients and Intercept
            print(f"\nCoefficients:")
            for i, coef in enumerate(model.coef_):
                feature_name = X.columns[i] if hasattr(X, "columns") else f"Feature_{i}"
                print(f"  {feature_name}: {coef:.5f}")

            print(f"\nIntercept: {model.intercept_:.5f}")

            coef_str = "\n".join(
                [
                    f"  {X.columns[i] if hasattr(X, 'columns') else f'Feature_{i}'}: {coef:.5f}"
                    for i, coef in enumerate(model.coef_)
                ]
            )

            # Store in metadata
            if self._csv.corpus is not None:
                self._csv.corpus.metadata["linear_regression_mse"] = (
                    f"Linear Regression MSE for predicting {y}: {mse:.5f}"
                )
                self._csv.corpus.metadata["linear_regression_r2"] = (
                    f"Linear Regression R² for predicting {y}: {r2:.5f}"
                )
                self._csv.corpus.metadata["linear_regression_coefficients"] = (
                    f"Coefficients:\n{coef_str}"
                )
                self._csv.corpus.metadata["linear_regression_intercept"] = (
                    f"Intercept: {model.intercept_:.5f}"
                )

            if mcp:
                return f"""
                Linear Regression MSE for predicting {y}: {mse:.5f}
                R²: {r2:.5f}
                Feature Names and Coefficients:
                {coef_str}
                Intercept: {model.intercept_:.5f}
                """
            return {
                "model_type": "linear",
                "mse": mse,
                "r2": r2,
                "coefficients": model.coef_,
                "intercept": model.intercept_,
                "feature_names": X.columns.tolist() if hasattr(X, "columns") else None,
            }

    def get_lstm_predictions(self, y: str, mcp=False):
        """
        Train an LSTM model on text data to predict an outcome variable.
        This tests if the texts converge towards predicting the outcome.

        Args:
            y (str): Name of the outcome variable in the DataFrame
            mcp (bool): If True, return a string format suitable for MCP

        Returns:
            Evaluation metrics as string (if mcp=True) or dict
        """
        if ML_INSTALLED is False:
            logger.error(
                "ML dependencies are not installed. Please install them by ```pip install crisp-t[ml] to use ML features."
            )
            if mcp:
                return "ML dependencies are not installed. Please install with: pip install crisp-t[ml]"
            return None

        if self._csv is None:
            logger.error("CSV data is not set.")
            if mcp:
                return "CSV data is not set. Cannot perform LSTM prediction."
            return None

        _corpus = self._csv.corpus
        if _corpus is None:
            logger.error("Corpus is not available.")
            if mcp:
                return "Corpus is not available. Cannot perform LSTM prediction."
            return None

        # Check if id_column exists
        id_column = "id"
        if not hasattr(self._csv, "df") or self._csv.df is None:
            logger.error("DataFrame is not available in CSV.")
            if mcp:
                return "This tool can be used only if texts and outcome variables align. DataFrame is missing."
            return None

        if id_column not in self._csv.df.columns:
            logger.error(
                f"The id_column '{id_column}' does not exist in the DataFrame."
            )
            if mcp:
                return f"This tool can be used only if texts and outcome variables align. The '{id_column}' column is missing from the DataFrame."
            return None

        # Check if outcome variable exists
        if y not in self._csv.df.columns:
            logger.error(f"The outcome variable '{y}' does not exist in the DataFrame.")
            if mcp:
                return f"The outcome variable '{y}' does not exist in the DataFrame."
            return None

        # Process documents and align with outcome variable
        try:
            # Build vocabulary from all documents
            from collections import Counter

            word_counts = Counter()
            tokenized_docs = []

            for doc in tqdm(_corpus.documents, desc="Tokenizing documents", disable=len(_corpus.documents) < 10):
                # Simple tokenization - split on whitespace and lowercase
                tokens = doc.text.lower().split()
                tokenized_docs.append(tokens)
                word_counts.update(tokens)

            # Create vocabulary with most common words (limit to 10000)
            vocab_size = min(10000, len(word_counts)) + 1  # +1 for padding
            most_common = word_counts.most_common(vocab_size - 1)
            word_to_idx = {
                word: idx + 1 for idx, (word, _) in enumerate(most_common)
            }  # 0 reserved for padding

            # Convert documents to sequences of indices
            max_length = 100  # Maximum sequence length
            sequences = []
            doc_ids = []

            for doc, tokens in tqdm(zip(_corpus.documents, tokenized_docs), total=len(_corpus.documents), desc="Converting to sequences", disable=len(_corpus.documents) < 10):
                # Convert tokens to indices
                seq = [word_to_idx.get(token, 0) for token in tokens]
                # Pad or truncate to max_length
                if len(seq) > max_length:
                    seq = seq[:max_length]
                else:
                    seq = seq + [0] * (max_length - len(seq))
                sequences.append(seq)
                doc_ids.append(doc.id)

            # Align with outcome variable using id column
            df = self._csv.df.set_index(id_column)

            aligned_sequences = []
            aligned_outcomes = []

            df_index_str = list(str(idx) for idx in df.index)
            for doc_id, seq in zip(doc_ids, sequences):
                if doc_id in df_index_str:
                    aligned_sequences.append(seq)
                    # Select y from df where id_column == doc_id, using string comparison
                    matched_row = df.loc[
                        [idx for idx in df.index if str(idx) == str(doc_id)]
                    ]
                    if not matched_row.empty:
                        aligned_outcomes.append(matched_row.iloc[0][y])

            if len(aligned_sequences) == 0:
                logger.error("No documents could be aligned with the outcome variable.")
                if mcp:
                    return "This tool can be used only if texts and outcome variables align. No matching IDs found."
                return None

            # Convert to tensors
            X_tensor = torch.LongTensor(aligned_sequences)  # type: ignore
            y_array = np.array(aligned_outcomes)

            # Handle binary classification
            unique_values = np.unique(y_array)
            num_classes = len(unique_values)

            if num_classes < 2:
                logger.error(
                    f"Need at least 2 classes for classification, found {num_classes}"
                )
                if mcp:
                    return f"Need at least 2 classes for classification, found {num_classes}"
                return None

            # Map to 0/1 for binary classification
            if num_classes == 2:
                class_mapping = {unique_values[0]: 0.0, unique_values[1]: 1.0}
                y_mapped = np.array(
                    [class_mapping[val] for val in y_array], dtype=np.float32
                )
            else:
                # Multi-class not supported in this simple LSTM implementation
                logger.error(
                    "Multi-class classification is not supported for LSTM. Please use binary outcome."
                )
                if mcp:
                    return "Multi-class classification is not supported for LSTM. Please use binary outcome."
                return None

            y_tensor = torch.FloatTensor(y_mapped).view(-1, 1)  # type: ignore

            # Split into train/test
            from sklearn.model_selection import train_test_split

            indices = list(range(len(X_tensor)))
            train_idx, test_idx = train_test_split(
                indices, test_size=0.2, random_state=42
            )

            X_train = X_tensor[train_idx]
            y_train = y_tensor[train_idx]
            X_test = X_tensor[test_idx]
            y_test = y_tensor[test_idx]

            # Create model
            model = SimpleLSTM(vocab_size=vocab_size)  # type: ignore
            criterion = nn.BCELoss()  # type: ignore
            optimizer = optim.Adam(model.parameters(), lr=0.001)  # type: ignore

            # Create data loaders
            train_dataset = TensorDataset(X_train, y_train)  # type: ignore
            train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)  # type: ignore

            # Training
            epochs = max(self._epochs, 3)  # Use at least 3 epochs for LSTM
            model.train()
            for epoch in range(epochs):
                total_loss = 0
                for batch_x, batch_y in train_loader:
                    optimizer.zero_grad()
                    predictions = model(batch_x)
                    loss = criterion(predictions, batch_y)
                    loss.backward()
                    optimizer.step()
                    total_loss += loss.item()

                avg_loss = total_loss / len(train_loader)
                logger.info(f"Epoch {epoch+1}/{epochs}, Loss: {avg_loss:.4f}")

            # Evaluation
            model.eval()
            with torch.no_grad():  # type: ignore
                train_preds = model(X_train)
                test_preds = model(X_test)

                train_preds_binary = (train_preds >= 0.5).float()
                test_preds_binary = (test_preds >= 0.5).float()

                train_accuracy = (train_preds_binary == y_train).float().mean().item()
                test_accuracy = (test_preds_binary == y_test).float().mean().item()

            # Calculate additional metrics for test set
            y_test_np = y_test.cpu().numpy().flatten()
            test_preds_np = test_preds_binary.cpu().numpy().flatten()

            # Confusion matrix elements
            tp = ((test_preds_np == 1) & (y_test_np == 1)).sum()
            tn = ((test_preds_np == 0) & (y_test_np == 0)).sum()
            fp = ((test_preds_np == 1) & (y_test_np == 0)).sum()
            fn = ((test_preds_np == 0) & (y_test_np == 1)).sum()

            # Calculate precision, recall, F1
            precision = tp / (tp + fp) if (tp + fp) > 0 else 0
            recall = tp / (tp + fn) if (tp + fn) > 0 else 0
            f1 = (
                2 * (precision * recall) / (precision + recall)
                if (precision + recall) > 0
                else 0
            )

            result_msg = (
                f"LSTM Model Evaluation for predicting '{y}':\n"
                f"  Vocabulary size: {vocab_size}\n"
                f"  Training samples: {len(X_train)}, Test samples: {len(X_test)}\n"
                f"  Epochs: {epochs}\n"
                f"  Train accuracy: {train_accuracy*100:.2f}%\n"
                f"  Test accuracy (convergence): {test_accuracy*100:.2f}%\n"
                f"  True Positive: {tp}, False Positive: {fp}, True Negative: {tn}, False Negative: {fn}\n"
                f"  Precision: {precision:.3f}\n"
                f"  Recall: {recall:.3f}\n"
                f"  F1-Score: {f1:.3f}\n"
            )

            print(f"\n{result_msg}")

            # Store in corpus metadata
            if _corpus is not None:
                _corpus.metadata["lstm_predictions"] = result_msg

            if mcp:
                return result_msg

            return {
                "vocab_size": vocab_size,
                "train_samples": len(X_train),
                "test_samples": len(X_test),
                "epochs": epochs,
                "train_accuracy": train_accuracy,
                "test_accuracy": test_accuracy,
                "true_positive": tp,
                "false_positive": fp,
                "true_negative": tn,
                "false_negative": fn,
                "precision": precision,
                "recall": recall,
                "f1_score": f1,
            }

        except Exception as e:
            logger.error(f"Error in LSTM prediction: {e}")
            if mcp:
                return f"Error in LSTM prediction: {e}"
            return None
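
Taken together, the methods above form a simple triangulation workflow: prepare X and y from the CSV, then run one or more classifiers or searches against the same outcome. The snippet below is a minimal usage sketch only; `ml` is a hypothetical instance of the class documented on this page, assumed to already hold a Csv object whose corpus and DataFrame are loaded, and the column name is illustrative.

# Hypothetical usage sketch -- method names match this page, the setup is assumed.
outcome = "satisfaction"            # assumed binary column in the CSV

# Nearest-neighbour lookup around record 3
dist, ind = ml.knn_search(outcome, n=3, r=3)

# Permutation-importance ranking via a random forest
cm, importance = ml.get_decision_tree_classes(outcome, top_n=5)

# Frequent itemsets over one-hot encoded features
itemsets = ml.get_apriori(outcome, min_support=0.6)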

format_confusion_matrix_to_human_readable(confusion_matrix)

Format the confusion matrix to a human-readable string.

Parameters:

    confusion_matrix (ndarray): The confusion matrix to format. Required.

Returns:

    str: The formatted confusion matrix with true positive, false positive, true negative, and false negative counts.
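
As a quick illustration (a sketch, not part of the library's own output), a 2x2 matrix in scikit-learn's [[tn, fp], [fn, tp]] layout unpacks as shown below; `ml` is a hypothetical instance of the class documented on this page.

import numpy as np

cm = np.array([[5, 1],
               [2, 4]])               # rows: actual 0/1, columns: predicted 0/1
print(ml.format_confusion_matrix_to_human_readable(cm))
# True Positive: 4
# False Positive: 1
# True Negative: 5
# False Negative: 2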

Source code in src/crisp_t/ml.py
def format_confusion_matrix_to_human_readable(
    self, confusion_matrix: np.ndarray
) -> str:
    """Format the confusion matrix to a human-readable string.

    Args:
        confusion_matrix (np.ndarray): The confusion matrix to format.

    Returns:
        str: The formatted confusion matrix with true positive, false positive, true negative, and false negative counts.
    """
    tn, fp, fn, tp = confusion_matrix.ravel()
    return (
        f"True Positive: {tp}\n"
        f"False Positive: {fp}\n"
        f"True Negative: {tn}\n"
        f"False Negative: {fn}\n"
    )

get_lstm_predictions(y, mcp=False)

Train an LSTM model on text data to predict an outcome variable. This tests if the texts converge towards predicting the outcome.

Parameters:

    y (str): Name of the outcome variable in the DataFrame. Required.
    mcp (bool): If True, return a string format suitable for MCP. Default: False.

Returns:

    Evaluation metrics as a string (if mcp=True) or a dict.
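
A minimal usage sketch, assuming the corpus documents carry ids that match the DataFrame's `id` column and the outcome column is binary (`ml` is a hypothetical instance of the class documented here):

# Returns a dict of metrics, or a formatted string when mcp=True
metrics = ml.get_lstm_predictions("outcome", mcp=False)
if metrics is not None:
    print(f"Test accuracy: {metrics['test_accuracy']*100:.2f}%")
    print(f"F1-score: {metrics['f1_score']:.3f}")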

Source code in src/crisp_t/ml.py
def get_lstm_predictions(self, y: str, mcp=False):
    """
    Train an LSTM model on text data to predict an outcome variable.
    This tests if the texts converge towards predicting the outcome.

    Args:
        y (str): Name of the outcome variable in the DataFrame
        mcp (bool): If True, return a string format suitable for MCP

    Returns:
        Evaluation metrics as string (if mcp=True) or dict
    """
    if ML_INSTALLED is False:
        logger.error(
            "ML dependencies are not installed. Please install them by ```pip install crisp-t[ml] to use ML features."
        )
        if mcp:
            return "ML dependencies are not installed. Please install with: pip install crisp-t[ml]"
        return None

    if self._csv is None:
        logger.error("CSV data is not set.")
        if mcp:
            return "CSV data is not set. Cannot perform LSTM prediction."
        return None

    _corpus = self._csv.corpus
    if _corpus is None:
        logger.error("Corpus is not available.")
        if mcp:
            return "Corpus is not available. Cannot perform LSTM prediction."
        return None

    # Check if id_column exists
    id_column = "id"
    if not hasattr(self._csv, "df") or self._csv.df is None:
        logger.error("DataFrame is not available in CSV.")
        if mcp:
            return "This tool can be used only if texts and outcome variables align. DataFrame is missing."
        return None

    if id_column not in self._csv.df.columns:
        logger.error(
            f"The id_column '{id_column}' does not exist in the DataFrame."
        )
        if mcp:
            return f"This tool can be used only if texts and outcome variables align. The '{id_column}' column is missing from the DataFrame."
        return None

    # Check if outcome variable exists
    if y not in self._csv.df.columns:
        logger.error(f"The outcome variable '{y}' does not exist in the DataFrame.")
        if mcp:
            return f"The outcome variable '{y}' does not exist in the DataFrame."
        return None

    # Process documents and align with outcome variable
    try:
        # Build vocabulary from all documents
        from collections import Counter

        word_counts = Counter()
        tokenized_docs = []

        for doc in tqdm(_corpus.documents, desc="Tokenizing documents", disable=len(_corpus.documents) < 10):
            # Simple tokenization - split on whitespace and lowercase
            tokens = doc.text.lower().split()
            tokenized_docs.append(tokens)
            word_counts.update(tokens)

        # Create vocabulary with most common words (limit to 10000)
        vocab_size = min(10000, len(word_counts)) + 1  # +1 for padding
        most_common = word_counts.most_common(vocab_size - 1)
        word_to_idx = {
            word: idx + 1 for idx, (word, _) in enumerate(most_common)
        }  # 0 reserved for padding

        # Convert documents to sequences of indices
        max_length = 100  # Maximum sequence length
        sequences = []
        doc_ids = []

        for doc, tokens in tqdm(zip(_corpus.documents, tokenized_docs), total=len(_corpus.documents), desc="Converting to sequences", disable=len(_corpus.documents) < 10):
            # Convert tokens to indices
            seq = [word_to_idx.get(token, 0) for token in tokens]
            # Pad or truncate to max_length
            if len(seq) > max_length:
                seq = seq[:max_length]
            else:
                seq = seq + [0] * (max_length - len(seq))
            sequences.append(seq)
            doc_ids.append(doc.id)

        # Align with outcome variable using id column
        df = self._csv.df.set_index(id_column)

        aligned_sequences = []
        aligned_outcomes = []

        df_index_str = list(str(idx) for idx in df.index)
        for doc_id, seq in zip(doc_ids, sequences):
            if doc_id in df_index_str:
                aligned_sequences.append(seq)
                # Select y from df where id_column == doc_id, using string comparison
                matched_row = df.loc[
                    [idx for idx in df.index if str(idx) == str(doc_id)]
                ]
                if not matched_row.empty:
                    aligned_outcomes.append(matched_row.iloc[0][y])

        if len(aligned_sequences) == 0:
            logger.error("No documents could be aligned with the outcome variable.")
            if mcp:
                return "This tool can be used only if texts and outcome variables align. No matching IDs found."
            return None

        # Convert to tensors
        X_tensor = torch.LongTensor(aligned_sequences)  # type: ignore
        y_array = np.array(aligned_outcomes)

        # Handle binary classification
        unique_values = np.unique(y_array)
        num_classes = len(unique_values)

        if num_classes < 2:
            logger.error(
                f"Need at least 2 classes for classification, found {num_classes}"
            )
            if mcp:
                return f"Need at least 2 classes for classification, found {num_classes}"
            return None

        # Map to 0/1 for binary classification
        if num_classes == 2:
            class_mapping = {unique_values[0]: 0.0, unique_values[1]: 1.0}
            y_mapped = np.array(
                [class_mapping[val] for val in y_array], dtype=np.float32
            )
        else:
            # Multi-class not supported in this simple LSTM implementation
            logger.error(
                "Multi-class classification is not supported for LSTM. Please use binary outcome."
            )
            if mcp:
                return "Multi-class classification is not supported for LSTM. Please use binary outcome."
            return None

        y_tensor = torch.FloatTensor(y_mapped).view(-1, 1)  # type: ignore

        # Split into train/test
        from sklearn.model_selection import train_test_split

        indices = list(range(len(X_tensor)))
        train_idx, test_idx = train_test_split(
            indices, test_size=0.2, random_state=42
        )

        X_train = X_tensor[train_idx]
        y_train = y_tensor[train_idx]
        X_test = X_tensor[test_idx]
        y_test = y_tensor[test_idx]

        # Create model
        model = SimpleLSTM(vocab_size=vocab_size)  # type: ignore
        criterion = nn.BCELoss()  # type: ignore
        optimizer = optim.Adam(model.parameters(), lr=0.001)  # type: ignore

        # Create data loaders
        train_dataset = TensorDataset(X_train, y_train)  # type: ignore
        train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)  # type: ignore

        # Training
        epochs = max(self._epochs, 3)  # Use at least 3 epochs for LSTM
        model.train()
        for epoch in range(epochs):
            total_loss = 0
            for batch_x, batch_y in train_loader:
                optimizer.zero_grad()
                predictions = model(batch_x)
                loss = criterion(predictions, batch_y)
                loss.backward()
                optimizer.step()
                total_loss += loss.item()

            avg_loss = total_loss / len(train_loader)
            logger.info(f"Epoch {epoch+1}/{epochs}, Loss: {avg_loss:.4f}")

        # Evaluation
        model.eval()
        with torch.no_grad():  # type: ignore
            train_preds = model(X_train)
            test_preds = model(X_test)

            train_preds_binary = (train_preds >= 0.5).float()
            test_preds_binary = (test_preds >= 0.5).float()

            train_accuracy = (train_preds_binary == y_train).float().mean().item()
            test_accuracy = (test_preds_binary == y_test).float().mean().item()

        # Calculate additional metrics for test set
        y_test_np = y_test.cpu().numpy().flatten()
        test_preds_np = test_preds_binary.cpu().numpy().flatten()

        # Confusion matrix elements
        tp = ((test_preds_np == 1) & (y_test_np == 1)).sum()
        tn = ((test_preds_np == 0) & (y_test_np == 0)).sum()
        fp = ((test_preds_np == 1) & (y_test_np == 0)).sum()
        fn = ((test_preds_np == 0) & (y_test_np == 1)).sum()

        # Calculate precision, recall, F1
        precision = tp / (tp + fp) if (tp + fp) > 0 else 0
        recall = tp / (tp + fn) if (tp + fn) > 0 else 0
        f1 = (
            2 * (precision * recall) / (precision + recall)
            if (precision + recall) > 0
            else 0
        )

        result_msg = (
            f"LSTM Model Evaluation for predicting '{y}':\n"
            f"  Vocabulary size: {vocab_size}\n"
            f"  Training samples: {len(X_train)}, Test samples: {len(X_test)}\n"
            f"  Epochs: {epochs}\n"
            f"  Train accuracy: {train_accuracy*100:.2f}%\n"
            f"  Test accuracy (convergence): {test_accuracy*100:.2f}%\n"
            f"  True Positive: {tp}, False Positive: {fp}, True Negative: {tn}, False Negative: {fn}\n"
            f"  Precision: {precision:.3f}\n"
            f"  Recall: {recall:.3f}\n"
            f"  F1-Score: {f1:.3f}\n"
        )

        print(f"\n{result_msg}")

        # Store in corpus metadata
        if _corpus is not None:
            _corpus.metadata["lstm_predictions"] = result_msg

        if mcp:
            return result_msg

        return {
            "vocab_size": vocab_size,
            "train_samples": len(X_train),
            "test_samples": len(X_test),
            "epochs": epochs,
            "train_accuracy": train_accuracy,
            "test_accuracy": test_accuracy,
            "true_positive": tp,
            "false_positive": fp,
            "true_negative": tn,
            "false_negative": fn,
            "precision": precision,
            "recall": recall,
            "f1_score": f1,
        }

    except Exception as e:
        logger.error(f"Error in LSTM prediction: {e}")
        if mcp:
            return f"Error in LSTM prediction: {e}"
        return None

get_nnet_predictions(y, mcp=False)

Extended: Handles binary (BCELoss) and multi-class (CrossEntropyLoss). Returns list of predicted original class labels.
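
A short usage sketch (hypothetical instance name `ml`; the outcome column is assumed to exist in the prepared data): binary targets are trained with BCELoss, three or more classes fall through to the CrossEntropyLoss path, and in both cases the return value is a list of labels in the original encoding.

preds = ml.get_nnet_predictions("outcome")        # list of original class labels
if preds is not None:
    from collections import Counter
    print(Counter(preds))                         # rough class balance of the predictions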

Source code in src/crisp_t/ml.py
def get_nnet_predictions(self, y: str, mcp=False):
    """
    Extended: Handles binary (BCELoss) and multi-class (CrossEntropyLoss).
    Returns list of predicted original class labels.
    """
    if ML_INSTALLED is False:
        logger.info(
            "ML dependencies are not installed. Please install them by ```pip install crisp-t[ml] to use ML features."
        )
        return None

    if self._csv is None:
        raise ValueError(
            "CSV data is not set. Please set self.csv before calling profile."
        )
    _corpus = self._csv.corpus

    X_np, Y_raw, X, Y = self._process_xy(y=y)

    unique_classes = np.unique(Y_raw)
    num_classes = unique_classes.size
    if num_classes < 2:
        raise ValueError(f"Need at least 2 classes; found {num_classes}.")

    vnum = X_np.shape[1]

    # Binary path
    if num_classes == 2:
        # Map to {0.0,1.0} for BCELoss if needed
        mapping_applied = False
        class_mapping = {}
        inverse_mapping = {}
        # Ensure deterministic order
        sorted_classes = sorted(unique_classes.tolist())
        if not (sorted_classes == [0, 1] or sorted_classes == [0.0, 1.0]):
            class_mapping = {sorted_classes[0]: 0.0, sorted_classes[1]: 1.0}
            inverse_mapping = {v: k for k, v in class_mapping.items()}
            Y_mapped = np.vectorize(class_mapping.get)(Y_raw).astype(np.float32)
            mapping_applied = True
        else:
            Y_mapped = Y_raw.astype(np.float32)

        model = NeuralNet(vnum)
        try:
            criterion = nn.BCELoss()  # type: ignore
            optimizer = optim.Adam(model.parameters(), lr=0.001)  # type: ignore

            X_tensor = torch.from_numpy(X_np)  # type: ignore
            y_tensor = torch.from_numpy(Y_mapped.astype(np.float32)).view(-1, 1)  # type: ignore

            dataset = TensorDataset(X_tensor, y_tensor)  # type: ignore
            dataloader = DataLoader(dataset, batch_size=32, shuffle=True)  # type: ignore
        except Exception as e:
            logger.error(f"Error occurred while creating DataLoader: {e}")
            return None

        for _ in range(self._epochs):
            for batch_X, batch_y in dataloader:
                optimizer.zero_grad()
                outputs = model(batch_X)
                loss = criterion(outputs, batch_y)
                if torch.isnan(loss):  # type: ignore
                    raise RuntimeError("NaN loss encountered.")
                loss.backward()
                optimizer.step()

        # Predictions
        bin_preds_internal = None
        if torch:
            with torch.no_grad():
                probs = model(torch.from_numpy(X_np)).view(-1).cpu().numpy()
            bin_preds_internal = (probs >= 0.5).astype(int)

        if mapping_applied:
            preds = [inverse_mapping[float(p)] for p in bin_preds_internal]  # type: ignore
            y_eval = np.vectorize(class_mapping.get)(Y_raw).astype(int)
            preds_eval = bin_preds_internal
        else:
            preds = bin_preds_internal.tolist()  # type: ignore
            y_eval = Y_mapped.astype(int)
            preds_eval = bin_preds_internal

        accuracy = (preds_eval == y_eval).sum() / len(y_eval)
        print(
            f"\nPredicting {y} with {X.shape[1]} features for {self._epochs} epochs gave an accuracy (convergence): {accuracy*100:.2f}%\n"
        )
        if _corpus is not None:
            _corpus.metadata["nnet_predictions"] = (
                f"Predicting {y} with {X.shape[1]} features for {self._epochs} epochs gave an accuracy (convergence): {accuracy*100:.2f}%"
            )
        if mcp:
            return f"Predicting {y} with {X.shape[1]} features for {self._epochs} epochs gave an accuracy (convergence): {accuracy*100:.2f}%"
        return preds

    # Multi-class path
    # Map original classes to indices
    sorted_classes = sorted(unique_classes.tolist())
    class_to_idx = {c: i for i, c in enumerate(sorted_classes)}
    idx_to_class = {i: c for c, i in class_to_idx.items()}
    Y_idx = np.vectorize(class_to_idx.get)(Y_raw).astype(np.int64)

    model = MultiClassNet(vnum, num_classes)
    criterion = nn.CrossEntropyLoss()  # type: ignore
    optimizer = optim.Adam(model.parameters(), lr=0.001)  # type: ignore

    X_tensor = torch.from_numpy(X_np)  # type: ignore
    y_tensor = torch.from_numpy(Y_idx)  # type: ignore

    dataset = TensorDataset(X_tensor, y_tensor)  # type: ignore
    dataloader = DataLoader(dataset, batch_size=32, shuffle=True)  # type: ignore

    for _ in range(self._epochs):
        for batch_X, batch_y in dataloader:
            optimizer.zero_grad()
            logits = model(batch_X)
            loss = criterion(logits, batch_y)
            if torch.isnan(loss):  # type: ignore
                raise RuntimeError("NaN loss encountered.")
            loss.backward()
            optimizer.step()

    with torch.no_grad():  # type: ignore
        logits_full = model(torch.from_numpy(X_np))  # type: ignore
        pred_indices = torch.argmax(logits_full, dim=1).cpu().numpy()  # type: ignore

    preds = [idx_to_class[i] for i in pred_indices]
    accuracy = (pred_indices == Y_idx).sum() / len(Y_idx)
    print(
        f"\nPredicting {y} with {X.shape[1]} features for {self._epochs} epochs gave an accuracy (convergence): {accuracy*100:.2f}%\n"
    )
    if _corpus is not None:
        _corpus.metadata["nnet_predictions"] = (
            f"Predicting {y} with {X.shape[1]} features for {self._epochs} epochs gave an accuracy (convergence): {accuracy*100:.2f}%"
        )
    if mcp:
        return f"Predicting {y} with {X.shape[1]} features for {self._epochs} epochs gave an accuracy (convergence): {accuracy*100:.2f}%"
    return preds

get_pca(y, n=3, mcp=False)

Perform a manual PCA (no sklearn PCA) on the feature matrix for target y.

Parameters:

    y (str): Target column name (used only for data preparation). Required.
    n (int): Number of principal components to keep. Default: 3.

Returns:

    dict: Keys 'covariance_matrix', 'eigenvalues', 'eigenvectors', 'explained_variance_ratio', 'cumulative_explained_variance_ratio', 'projection_matrix', and 'transformed'.
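
Because the transform is simply X_std @ W, with W the first n eigenvectors of the covariance matrix, the returned pieces can be checked against each other. A small sketch (hypothetical `ml` instance; the column name is illustrative):

import numpy as np

res = ml.get_pca("outcome", n=2)
W = res["projection_matrix"]                       # shape (n_features, 2)
X2 = res["transformed"]                            # shape (n_samples, 2)

# The leading eigenvalues of the covariance matrix should roughly match
# the variance of the projected components (np.cov uses ddof=1).
print(np.allclose(np.var(X2, axis=0, ddof=1), res["eigenvalues"][:2]))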

Source code in src/crisp_t/ml.py
def get_pca(self, y: str, n: int = 3, mcp=False):
    """
    Perform a manual PCA (no sklearn PCA) on the feature matrix for target y.

    Args:
        y (str): Target column name (used only for data preparation).
        n (int): Number of principal components to keep.

    Returns:
        dict: {
            'covariance_matrix': cov_mat,
            'eigenvalues': eig_vals_sorted,
            'eigenvectors': eig_vecs_sorted,
            'explained_variance_ratio': var_exp,
            'cumulative_explained_variance_ratio': cum_var_exp,
            'projection_matrix': matrix_w,
            'transformed': X_pca
        }
    """
    X_np, Y_raw, X, Y = self._process_xy(y=y)
    X_std = StandardScaler().fit_transform(X_np)

    cov_mat = np.cov(X_std.T)
    eig_vals, eig_vecs = np.linalg.eigh(cov_mat)  # symmetric matrix -> eigh

    # Sort eigenvalues (and vectors) descending
    idx = np.argsort(eig_vals)[::-1]
    eig_vals_sorted = eig_vals[idx]
    eig_vecs_sorted = eig_vecs[:, idx]

    factors = X_std.shape[1]
    n = max(1, min(n, factors))

    # Explained variance ratios
    tot = eig_vals_sorted.sum()
    var_exp = (eig_vals_sorted / tot) * 100.0
    cum_var_exp = np.cumsum(var_exp)

    # Projection matrix (first n eigenvectors)
    matrix_w = eig_vecs_sorted[:, :n]

    # Project data
    X_pca = X_std @ matrix_w

    # Optional prints (retain original behavior)
    print("Covariance matrix:\n", cov_mat)
    print("Eigenvalues (desc):\n", eig_vals_sorted)
    print("Explained variance (%):\n", var_exp[:n])
    print("Cumulative explained variance (%):\n", cum_var_exp[:n])
    print("Projection matrix (W):\n", matrix_w)
    print("Transformed (first 5 rows):\n", X_pca[:5])

    result = {
        "covariance_matrix": cov_mat,
        "eigenvalues": eig_vals_sorted,
        "eigenvectors": eig_vecs_sorted,
        "explained_variance_ratio": var_exp,
        "cumulative_explained_variance_ratio": cum_var_exp,
        "projection_matrix": matrix_w,
        "transformed": X_pca,
    }

    if self._csv.corpus is not None:
        self._csv.corpus.metadata["pca"] = (
            f"PCA kept {n} components explaining "
            f"{cum_var_exp[n-1]:.2f}% variance."
        )
    if mcp:
        return (
            f"PCA kept {n} components explaining {cum_var_exp[n-1]:.2f}% variance."
        )
    return result

get_regression(y, mcp=False)

Perform linear or logistic regression based on the outcome variable type.

If the outcome is binary, fit a logistic regression model. Otherwise, fit a linear regression model.

Parameters:

    y (str): Target column name for the regression. Required.

Returns:

    dict: Regression results including coefficients, intercept, and metrics.
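
A brief sketch of how the two branches differ in what they return (hypothetical `ml` instance; the column names are assumptions):

res = ml.get_regression("age")              # continuous target -> linear branch
if res is not None and res["model_type"] == "linear":
    print(res["mse"], res["r2"])

res = ml.get_regression("readmitted")       # binary target -> logistic branch
if res is not None and res["model_type"] == "logistic":
    print(res["accuracy"])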

Source code in src/crisp_t/ml.py
def get_regression(self, y: str, mcp=False):
    """
    Perform linear or logistic regression based on the outcome variable type.

    If the outcome is binary, fit a logistic regression model.
    Otherwise, fit a linear regression model.

    Args:
        y (str): Target column name for the regression.

    Returns:
        dict: Regression results including coefficients, intercept, and metrics.
    """
    if ML_INSTALLED is False:
        logger.info(
            "ML dependencies are not installed. Please install them by ```pip install crisp-t[ml] to use ML features."
        )
        return None

    if self._csv is None:
        raise ValueError(
            "CSV data is not set. Please set self.csv before calling get_regression."
        )

    X_np, Y_raw, X, Y = self._process_xy(y=y)

    # Check if outcome is binary (logistic) or continuous (linear)
    unique_values = np.unique(Y_raw)
    num_unique = len(unique_values)

    # Determine if binary classification or regression
    is_binary = num_unique == 2

    if is_binary:
        # Logistic Regression
        print(f"\n=== Logistic Regression for {y} ===")
        print(f"Binary outcome detected with values: {unique_values}")

        model = LogisticRegression(max_iter=1000, random_state=42)
        model.fit(X_np, Y_raw)

        # Predictions
        y_pred = model.predict(X_np)

        # Accuracy
        accuracy = accuracy_score(Y_raw, y_pred)
        print(f"\nAccuracy: {accuracy*100:.2f}%")

        # Coefficients and Intercept
        print(f"\nCoefficients:")
        for i, coef in enumerate(model.coef_[0]):
            feature_name = X.columns[i] if hasattr(X, "columns") else f"Feature_{i}"
            print(f"  {feature_name}: {coef:.5f}")

        print(f"\nIntercept: {model.intercept_[0]:.5f}")

        coef_str = "\n".join(
            [
                f"  {X.columns[i] if hasattr(X, 'columns') else f'Feature_{i}'}: {coef:.5f}"
                for i, coef in enumerate(model.coef_[0])
            ]
        )

        # Store in metadata
        if self._csv.corpus is not None:
            self._csv.corpus.metadata["logistic_regression_accuracy"] = (
                f"Logistic Regression accuracy for predicting {y}: {accuracy*100:.2f}%"
            )
            self._csv.corpus.metadata["logistic_regression_coefficients"] = (
                f"Coefficients:\n{coef_str}"
            )
            self._csv.corpus.metadata["logistic_regression_intercept"] = (
                f"Intercept: {model.intercept_[0]:.5f}"
            )

        if mcp:
            return f"""
            Logistic Regression accuracy for predicting {y}: {accuracy*100:.2f}%
            Coefficients:
            {coef_str}
            Intercept: {model.intercept_[0]:.5f}
            """
        return {
            "model_type": "logistic",
            "accuracy": accuracy,
            "coefficients": model.coef_[0],
            "intercept": model.intercept_[0],
            "feature_names": X.columns.tolist() if hasattr(X, "columns") else None,
        }
    else:
        # Linear Regression
        print(f"\n=== Linear Regression for {y} ===")
        print(f"Continuous outcome detected with {num_unique} unique values")

        model = LinearRegression()
        model.fit(X_np, Y_raw)

        # Predictions
        y_pred = model.predict(X_np)

        # Metrics
        mse = mean_squared_error(Y_raw, y_pred)
        r2 = r2_score(Y_raw, y_pred)
        print(f"\nMean Squared Error (MSE): {mse:.5f}")
        print(f"R² Score: {r2:.5f}")

        # Coefficients and Intercept
        print(f"\nCoefficients:")
        for i, coef in enumerate(model.coef_):
            feature_name = X.columns[i] if hasattr(X, "columns") else f"Feature_{i}"
            print(f"  {feature_name}: {coef:.5f}")

        print(f"\nIntercept: {model.intercept_:.5f}")

        coef_str = "\n".join(
            [
                f"  {X.columns[i] if hasattr(X, 'columns') else f'Feature_{i}'}: {coef:.5f}"
                for i, coef in enumerate(model.coef_)
            ]
        )

        # Store in metadata
        if self._csv.corpus is not None:
            self._csv.corpus.metadata["linear_regression_mse"] = (
                f"Linear Regression MSE for predicting {y}: {mse:.5f}"
            )
            self._csv.corpus.metadata["linear_regression_r2"] = (
                f"Linear Regression R² for predicting {y}: {r2:.5f}"
            )
            self._csv.corpus.metadata["linear_regression_coefficients"] = (
                f"Coefficients:\n{coef_str}"
            )
            self._csv.corpus.metadata["linear_regression_intercept"] = (
                f"Intercept: {model.intercept_:.5f}"
            )

        if mcp:
            return f"""
            Linear Regression MSE for predicting {y}: {mse:.5f}
            R²: {r2:.5f}
            Feature Names and Coefficients:
            {coef_str}
            Intercept: {model.intercept_:.5f}
            """
        return {
            "model_type": "linear",
            "mse": mse,
            "r2": r2,
            "coefficients": model.coef_,
            "intercept": model.intercept_,
            "feature_names": X.columns.tolist() if hasattr(X, "columns") else None,
        }

svm_confusion_matrix(y, test_size=0.25, random_state=0, mcp=False)

Generate a confusion matrix for a linear SVM classifier.

Returns:

    ndarray: Confusion matrix for the SVM classifier on the held-out test split.
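
Usage sketch (hypothetical `ml` instance; column name is illustrative): the target is first collapsed to 0/1 by _convert_to_binary, so a two-level or multi-level column can be passed.

cm = ml.svm_confusion_matrix("outcome", test_size=0.25, random_state=0)
tn, fp, fn, tp = cm.ravel()
if (tp + fn) > 0:
    print(f"Sensitivity: {tp / (tp + fn):.2f}")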

Source code in src/crisp_t/ml.py
def svm_confusion_matrix(self, y: str, test_size=0.25, random_state=0, mcp=False):
    """Generate confusion matrix for SVM

    Returns:
        [list] -- [description]
    """
    X_np, Y_raw, X, Y = self._process_xy(y=y)
    Y = self._convert_to_binary(Y)
    X_train, X_test, y_train, y_test = train_test_split(
        X, Y, test_size=test_size, random_state=random_state
    )
    sc = StandardScaler()
    # Issue #22
    y_test = y_test.astype("int")
    y_train = y_train.astype("int")
    X_train = sc.fit_transform(X_train)
    X_test = sc.transform(X_test)
    classifier = SVC(kernel="linear", random_state=0)
    classifier.fit(X_train, y_train)
    y_pred = classifier.predict(X_test)
    # Issue #22
    y_pred = y_pred.astype("int")
    _confusion_matrix = confusion_matrix(y_test, y_pred)
    print(f"Confusion Matrix for SVM predicting {y}:\n{_confusion_matrix}")
    # Output
    # [[2 0]
    #  [2 0]]
    if self._csv.corpus is not None:
        self._csv.corpus.metadata["svm_confusion_matrix"] = (
            f"Confusion Matrix for SVM predicting {y}:\n{self.format_confusion_matrix_to_human_readable(_confusion_matrix)}"
        )

    if mcp:
        return f"Confusion Matrix for SVM predicting {y}:\n{self.format_confusion_matrix_to_human_readable(_confusion_matrix)}"

    return _confusion_matrix
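
The same pattern (linear SVC on scaled features with a binary outcome) as a standalone scikit-learn sketch on toy data, independent of crisp-t:

import numpy as np
from sklearn.metrics import confusion_matrix
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.svm import SVC

rng = np.random.default_rng(0)
X = rng.normal(size=(200, 3))
y = (X[:, 0] + X[:, 1] > 0).astype(int)  # binary outcome

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.25, random_state=0)
sc = StandardScaler()
X_train = sc.fit_transform(X_train)
X_test = sc.transform(X_test)

clf = SVC(kernel="linear", random_state=0).fit(X_train, y_train)
cm = confusion_matrix(y_test, clf.predict(X_test))
# Rows are true labels, columns are predicted labels:
# [[true negatives, false positives],
#  [false negatives, true positives]]
print(cm)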

Copyright (C) 2025 Bell Eapen

This file is part of crisp-t.

crisp-t is free software: you can redistribute it and/or modify it under the terms of the GNU General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your option) any later version.

crisp-t is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for more details.

You should have received a copy of the GNU General Public License along with crisp-t. If not, see https://www.gnu.org/licenses/.

Network

A class to represent a network of documents and their relationships.

Source code in src/crisp_t/network.py
class Network:
    """
    A class to represent a network of documents and their relationships.
    """

    def __init__(self, corpus: Corpus):
        """
        Initialize the Network with a corpus.

        :param corpus: Corpus object containing documents to be included in the network.
        """
        self._corpus = corpus
        self._cluster = Cluster(corpus)
        self._processed_docs = self._cluster.processed_docs
        self._graph = None

    def cooccurence_network(self, window_size=2):
        """Build a word co-occurrence network from the processed documents."""
        self._graph = network.build_cooccurrence_network(
            self._processed_docs, window_size=window_size
        )
        return self._graph

    def similarity_network(self, method="levenshtein"):
        """Build a sentence similarity network using the given string-similarity method."""
        text = Text(self._corpus)
        docs = text.make_spacy_doc()
        data = [sent.text.lower() for sent in docs.sents]
        self._graph = network.build_similarity_network(data, method)
        return self._graph

    def graph_as_dict(self):
        """
        Convert the graph to a dictionary representation.

        :return: Dictionary representation of the graph.
        """
        if self._graph is None:
            raise ValueError(
                "Graph has not been created yet. Call cooccurence_network() or similarity_network() first."
            )
        # Returns the (node, neighbours) adjacency pair of the first node in sorted order
        return sorted(self._graph.adjacency())[0]

__init__(corpus)

Initialize the Network with a corpus.

:param corpus: Corpus object containing documents to be included in the network.

Source code in src/crisp_t/network.py
def __init__(self, corpus: Corpus):
    """
    Initialize the Network with a corpus.

    :param corpus: Corpus object containing documents to be included in the network.
    """
    self._corpus = corpus
    self._cluster = Cluster(corpus)
    self._processed_docs = self._cluster.processed_docs
    self._graph = None

graph_as_dict()

Convert the graph to a dictionary representation.

:return: Dictionary representation of the graph.

Source code in src/crisp_t/network.py
def graph_as_dict(self):
    """
    Convert the graph to a dictionary representation.

    :return: Dictionary representation of the graph.
    """
    if self._graph is None:
        raise ValueError(
            "Graph has not been created yet. Call cooccurence_network() or similarity_network() first."
        )
    # Returns the (node, neighbours) adjacency pair of the first node in sorted order
    return sorted(self._graph.adjacency())[0]
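
A minimal usage sketch of the class above. The `corpus` variable stands for an existing crisp-t Corpus, and the import path is assumed from the file layout (src/crisp_t/network.py):

from crisp_t.network import Network  # import path assumed from the file layout

# `corpus` is assumed to be an existing crisp_t Corpus instance
net = Network(corpus)
graph = net.cooccurence_network(window_size=2)      # word co-occurrence graph
# or: graph = net.similarity_network(method="levenshtein")
print(net.graph_as_dict())                          # adjacency entry of the first node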

QRVisualize

Source code in src/crisp_t/visualize.py
class QRVisualize:
    """Visualization helpers for topic, network, and corpus-level analysis results."""

    def __init__(
        self, corpus: Corpus | None = None, folder_path: str | None = None
    ) -> None:
        # Matplotlib figure components assigned lazily by plotting methods
        self.corpus = corpus
        self.folder_path = folder_path
        self.fig: Figure | None = None
        self.ax: Axes | None = None
        self.sc: PathCollection | None = None
        self.annot: Annotation | None = None
        self.names: list[str] = []
        self.c: np.ndarray | None = None

    def _ensure_columns(
        self, df: pd.DataFrame, required: Iterable[str]
    ) -> pd.DataFrame:
        """Ensure that the DataFrame has the required columns.

        Behavior:
        - If all required columns already exist, return df unchanged.
        - If the DataFrame has exactly the same number of columns as required,
          rename columns positionally to match the required names.
        - Otherwise, raise a ValueError listing the missing columns.
        """
        required = list(required)
        # Fast path: all required columns present
        missing = [col for col in required if col not in df.columns]
        if not missing:
            return df

        # If shape matches, attempt a positional rename
        if len(df.columns) == len(required):
            df = df.copy()
            df.columns = required
            return df

        # Otherwise, cannot satisfy required columns
        raise ValueError(f"Missing required columns: {missing}")

    def _finalize_plot(
        self,
        fig: Figure,
        folder_path: str | None,
        show: bool,
    ) -> Figure:
        if not folder_path:
            folder_path = self.folder_path
        if folder_path:
            output_path = Path(folder_path)
            if output_path.parent:
                output_path.parent.mkdir(parents=True, exist_ok=True)
            fig.savefig(folder_path)
        if show:
            plt.show(block=False)
        else:
            plt.close(fig)
        return fig

    def plot_frequency_distribution_of_words(
        self,
        df: pd.DataFrame | None = None,
        folder_path: str | None = None,
        text_column: str = "Text",
        bins: int = 100,
        show: bool = True,
    ) -> Tuple[Figure, Axes]:
        if df is None:
            try:
                df = pd.DataFrame(self.corpus.visualization["assign_topics"])
            except Exception as e:
                raise ValueError(f"Failed to create DataFrame from corpus: {e}")
        df = self._ensure_columns(df, [text_column])
        doc_lens = df[text_column].dropna().map(len).tolist()
        if not doc_lens:
            raise ValueError("No documents available to plot frequency distribution.")

        fig, ax = plt.subplots(figsize=(16, 7), dpi=160)
        counts, _, _ = ax.hist(doc_lens, bins=bins, color="navy")
        counts = np.asarray(counts)
        if counts.size:
            ax.set_ylim(top=float(counts.max()) * 1.1)

        stats = {
            "Mean": round(np.mean(doc_lens), 2),
            "Median": round(np.median(doc_lens), 2),
            "Stdev": round(np.std(doc_lens), 2),
            "1%ile": round(np.quantile(doc_lens, q=0.01), 2),
            "99%ile": round(np.quantile(doc_lens, q=0.99), 2),
        }
        for idx, (label, value) in enumerate(stats.items()):
            ax.text(
                0.98,
                0.98 - idx * 0.05,
                f"{label}: {value}",
                transform=ax.transAxes,
                ha="right",
                va="top",
                fontsize=11,
            )

        ax.set(
            ylabel="Number of Documents",
            xlabel="Document Word Count",
            title="Distribution of Document Word Counts",
        )
        ax.tick_params(axis="both", labelsize=12)
        if doc_lens:
            ax.set_xlim(left=0, right=max(doc_lens) * 1.05)

        fig = self._finalize_plot(fig, folder_path, show)
        return fig, ax

    def plot_distribution_by_topic(
        self,
        df: pd.DataFrame | None = None,
        folder_path: str | None = None,
        topic_column: str = "Dominant_Topic",
        text_column: str = "Text",
        bins: int = 100,
        show: bool = True,
    ) -> Tuple[Figure, np.ndarray]:
        if df is None:
            try:
                df = pd.DataFrame(self.corpus.visualization["assign_topics"])
            except Exception as e:
                raise ValueError(f"Failed to create DataFrame from corpus: {e}")
        df = self._ensure_columns(df, [topic_column, text_column])
        unique_topics = sorted(df[topic_column].dropna().unique())
        if not unique_topics:
            raise ValueError("No topics found to plot distribution.")

        n_topics = len(unique_topics)
        n_cols = min(3, n_topics)
        n_rows = math.ceil(n_topics / n_cols)
        cols = list(mcolors.TABLEAU_COLORS.values())

        fig, axes = plt.subplots(
            n_rows,
            n_cols,
            figsize=(6 * n_cols, 5 * n_rows),
            dpi=160,
            sharex=True,
            sharey=True,
        )
        if isinstance(axes, np.ndarray):
            axes_flat = axes.flatten().tolist()
        else:
            axes_flat = [axes]

        for idx, topic in enumerate(unique_topics):
            ax = axes_flat[idx]
            topic_series = cast(
                pd.Series,
                df.loc[df[topic_column] == topic, text_column],
            )
            topic_docs = topic_series.dropna()
            doc_lens = topic_docs.map(len).tolist()
            color = cols[idx % len(cols)]
            if doc_lens:
                ax.hist(doc_lens, bins=bins, color=color, alpha=0.7)
                sns.kdeplot(
                    doc_lens,
                    color="black",
                    fill=False,
                    ax=ax.twinx(),
                    warn_singular=False,
                )
            ax.set(xlabel="Document Word Count")
            ax.set_ylabel("Number of Documents", color=color)
            ax.set_title(f"Topic: {topic}", fontdict=dict(size=14, color=color))
            ax.tick_params(axis="y", labelcolor=color, color=color)

        for extra_ax in axes_flat[len(unique_topics) :]:
            extra_ax.set_visible(False)

        fig.tight_layout()
        fig.suptitle(
            "Distribution of Document Word Counts by Dominant Topic",
            fontsize=20,
            y=1.02,
        )

        fig = self._finalize_plot(fig, folder_path, show)
        axes_array = np.array(axes_flat, dtype=object).reshape(n_rows, n_cols)
        return fig, axes_array

    def plot_wordcloud(
        self,
        topics=None,
        folder_path: str | None = None,
        max_words: int = 50,
        show: bool = True,
    ) -> Tuple[Figure, np.ndarray]:
        if not topics:
            try:
                topics = self.corpus.visualization["word_cloud"]
            except Exception as e:
                raise ValueError(f"Failed to retrieve topics from corpus: {e}")
        n_topics = len(topics)
        n_cols = min(3, n_topics)
        n_rows = math.ceil(n_topics / n_cols)
        cols = list(mcolors.TABLEAU_COLORS.values())

        fig, axes = plt.subplots(
            n_rows,
            n_cols,
            figsize=(6 * n_cols, 4 * n_rows),
            sharex=True,
            sharey=True,
        )
        axes_flat = axes.flatten().tolist() if isinstance(axes, np.ndarray) else [axes]

        for idx, (topic_id, words) in enumerate(topics):
            ax = axes_flat[idx]
            topic_words = dict(words)
            color = cols[idx % len(cols)]
            cloud = WordCloud(
                stopwords=STOPWORDS,
                background_color="white",
                width=800,
                height=400,
                max_words=max_words,
                colormap="tab10",
                color_func=lambda *args, color=color, **kwargs: color,
                prefer_horizontal=0.9,
            )
            cloud.generate_from_frequencies(topic_words)
            ax.imshow(cloud)
            ax.set_title(f"Topic {topic_id}", fontdict=dict(size=14))
            ax.axis("off")

        for extra_ax in axes_flat[len(topics) :]:
            extra_ax.set_visible(False)

        fig.tight_layout()

        fig = self._finalize_plot(fig, folder_path, show)
        return fig, np.array(axes_flat).reshape(n_rows, n_cols)

    def plot_top_terms(
        self,
        df: pd.DataFrame | None = None,
        term_column: str = "term",
        frequency_column: str = "frequency",
        top_n: int = 20,
        folder_path: str | None = None,
        ascending: bool = False,
        show: bool = True,
    ) -> Tuple[Figure, Axes]:
        if df is None:
            try:
                df = pd.DataFrame(self.corpus.visualization["assign_topics"])
            except Exception as e:
                raise ValueError(f"Failed to create DataFrame from corpus: {e}")
        if top_n <= 0:
            raise ValueError("top_n must be greater than zero.")

        df = self._ensure_columns(df, [term_column, frequency_column])
        subset = df[[term_column, frequency_column]].dropna()
        if subset.empty:
            raise ValueError("No data available to plot top terms.")

        subset = subset.sort_values(frequency_column, ascending=ascending).head(top_n)
        subset = subset.iloc[::-1]

        fig, ax = plt.subplots(figsize=(10, max(4, top_n * 0.4)))
        ax.barh(subset[term_column], subset[frequency_column], color="steelblue")
        ax.set_xlabel("Frequency")
        ax.set_ylabel("Term")
        ax.set_title("Top Terms by Frequency")
        for idx, value in enumerate(subset[frequency_column]):
            ax.text(value, idx, f" {value}", va="center")
        fig.tight_layout()

        fig = self._finalize_plot(fig, folder_path, show)
        return fig, ax

    def plot_correlation_heatmap(
        self,
        df: pd.DataFrame | None = None,
        columns: Sequence[str] | None = None,
        folder_path: str | None = None,
        cmap: str = "coolwarm",
        show: bool = True,
    ) -> Tuple[Figure, Axes]:
        if df is None:
            try:
                df = pd.DataFrame(self.corpus.visualization["assign_topics"])
            except Exception as e:
                raise ValueError(f"Failed to create DataFrame from corpus: {e}")
        if columns:
            df = self._ensure_columns(df, columns)
            data = df[list(columns)]
        else:
            data = df
        if data.empty:
            raise ValueError("No data available to compute correlation heatmap.")

        numeric_data = data.select_dtypes(include=[np.number])
        if numeric_data.shape[1] < 2:
            raise ValueError(
                "At least two numeric columns are required for correlation heatmap."
            )

        corr = numeric_data.corr()
        fig, ax = plt.subplots(figsize=(8, 6))
        sns.heatmap(corr, ax=ax, cmap=cmap, annot=True, fmt=".2f", square=True)
        ax.set_title("Correlation Heatmap")
        fig.tight_layout()

        fig = self._finalize_plot(fig, folder_path, show)
        return fig, ax

    def plot_importance(
        self,
        topics: Sequence[Tuple[int, Sequence[Tuple[str, float]]]],
        processed_docs: Sequence[Sequence[str]],
        folder_path: str | None = None,
        show: bool = True,
    ) -> Tuple[Figure, np.ndarray]:
        if not topics:
            raise ValueError("No topics provided to plot importance.")
        if not processed_docs:
            raise ValueError("No processed documents provided to plot importance.")

        counter = Counter(word for doc in processed_docs for word in doc)
        rows = []
        for topic_id, words in topics:
            for word, weight in words:
                rows.append(
                    {
                        "word": word,
                        "topic_id": topic_id,
                        "importance": weight,
                        "word_count": counter.get(word, 0),
                    }
                )

        df = pd.DataFrame(rows)
        if df.empty:
            raise ValueError("Unable to build importance DataFrame from inputs.")

        topic_ids = sorted(df["topic_id"].unique())
        n_topics = len(topic_ids)
        n_cols = min(3, n_topics)
        n_rows = math.ceil(n_topics / n_cols)
        cols = list(mcolors.TABLEAU_COLORS.values())

        fig, axes = plt.subplots(
            n_rows,
            n_cols,
            figsize=(7 * n_cols, 5 * n_rows),
            sharey=False,
            dpi=160,
        )
        axes_flat = axes.flatten().tolist() if isinstance(axes, np.ndarray) else [axes]

        for idx, topic_id in enumerate(topic_ids):
            ax = axes_flat[idx]
            subset = df[df["topic_id"] == topic_id]
            color = cols[idx % len(cols)]
            ax.bar(
                subset["word"],
                subset["word_count"],
                color=color,
                width=0.5,
                alpha=0.4,
                label="Word Count",
            )
            ax_twin = ax.twinx()
            ax_twin.plot(
                subset["word"],
                subset["importance"],
                color=color,
                marker="o",
                label="Importance",
            )
            ax.set_title(f"Topic {topic_id}", color=color, fontsize=14)
            ax.set_xlabel("Word")
            ax.set_ylabel("Word Count", color=color)
            ax.tick_params(axis="y", labelcolor=color)
            ax_twin.set_ylabel("Importance", color=color)
            ax_twin.tick_params(axis="y", labelcolor=color)
            ax.set_xticklabels(subset["word"], rotation=30, ha="right")
            ax.legend(loc="upper left")
            ax_twin.legend(loc="upper right")

        for extra_ax in axes_flat[len(topic_ids) :]:
            extra_ax.set_visible(False)

        fig.tight_layout()
        fig.suptitle(
            "Word Count and Importance of Topic Keywords",
            fontsize=20,
            y=1.02,
        )

        fig = self._finalize_plot(fig, folder_path, show)
        return fig, np.array(axes_flat).reshape(n_rows, n_cols)

    def sentence_chart(self, lda_model, text, start=0, end=13, folder_path=None):
        """Color each document's words by their per-word dominant topic, boxed in the document's dominant-topic color."""
        if lda_model is None:
            raise ValueError("LDA model is not provided.")
        corp = text[start:end]
        mycolors = [color for name, color in mcolors.TABLEAU_COLORS.items()]

        fig, axes = plt.subplots(
            end - start, 1, figsize=(20, (end - start) * 0.95), dpi=160
        )
        axes[0].axis("off")
        for i, ax in enumerate(axes):
            try:
                if i > 0:
                    corp_cur = corp[i - 1]
                    topic_percs, wordid_topics, _ = lda_model[corp_cur]
                    word_dominanttopic = [
                        (lda_model.id2word[wd], topic[0]) for wd, topic in wordid_topics
                    ]
                    ax.text(
                        0.01,
                        0.5,
                        "Doc " + str(i - 1) + ": ",
                        verticalalignment="center",
                        fontsize=16,
                        color="black",
                        transform=ax.transAxes,
                        fontweight=700,
                    )

                    # Draw Rectangle
                    topic_percs_sorted = sorted(
                        topic_percs, key=lambda x: (x[1]), reverse=True
                    )
                    ax.add_patch(
                        Rectangle(
                            (0.0, 0.05),
                            0.99,
                            0.90,
                            fill=None,
                            alpha=1,
                            color=mycolors[topic_percs_sorted[0][0]],
                            linewidth=2,
                        )
                    )

                    word_pos = 0.06
                    for j, (word, topics) in enumerate(word_dominanttopic):
                        if j < 14:
                            ax.text(
                                word_pos,
                                0.5,
                                word,
                                horizontalalignment="left",
                                verticalalignment="center",
                                fontsize=16,
                                color=mycolors[topics],
                                transform=ax.transAxes,
                                fontweight=700,
                            )
                            word_pos += 0.009 * len(
                                word
                            )  # to move the word for the next iter
                            ax.axis("off")
                    ax.text(
                        word_pos,
                        0.5,
                        ". . .",
                        horizontalalignment="left",
                        verticalalignment="center",
                        fontsize=16,
                        color="black",
                        transform=ax.transAxes,
                    )
            except Exception as e:
                logger.error(f"Error occurred while processing document {i - 1}: {e}")
                continue

        plt.subplots_adjust(wspace=0, hspace=0)
        plt.suptitle(
            "Sentence Topic Coloring for Documents: "
            + str(start)
            + " to "
            + str(end - 2),
            fontsize=22,
            y=0.95,
            fontweight=700,
        )
        plt.tight_layout()
        plt.show(block=False)
        # save
        if folder_path:
            plt.savefig(folder_path)
            plt.close()

    def _cluster_chart(self, lda_model, text, n_topics=3, folder_path=None):
        """Reduce per-document topic weights to 2-D with t-SNE and scatter them by dominant topic."""
        # Get topic weights
        topic_weights = []
        for i, row_list in enumerate(lda_model[text]):
            topic_weights.append([w for i, w in row_list[0]])

        # Array of topic weights
        arr = pd.DataFrame(topic_weights).fillna(0).values

        # Keep the well separated points (optional)
        arr = arr[np.amax(arr, axis=1) > 0.35]

        # Dominant topic number in each doc
        topic_num = np.argmax(arr, axis=1)

        # tSNE Dimension Reduction
        tsne_model = TSNE(
            n_components=2, verbose=1, random_state=0, angle=0.99, init="pca"
        )
        tsne_lda = tsne_model.fit_transform(arr)

        # Plot
        plt.figure(figsize=(16, 10), dpi=160)
        for i in range(n_topics):
            plt.scatter(
                tsne_lda[topic_num == i, 0],
                tsne_lda[topic_num == i, 1],
                label=str(i),
                alpha=0.5,
            )
        plt.title("t-SNE Clustering of Topics", fontsize=22)
        plt.xlabel("t-SNE Dimension 1", fontsize=16)
        plt.ylabel("t-SNE Dimension 2", fontsize=16)
        plt.legend(title="Topic Number", loc="upper right")
        plt.show(block=False)
        # save
        if folder_path:
            plt.savefig(folder_path)
            plt.close()

    def most_discussed_topics(
        self, lda_model, dominant_topics, topic_percentages, folder_path=None
    ):
        """Plot document counts per dominant topic and total topic weightage, labelled with each topic's top-3 keywords."""
        # Distribution of Dominant Topics in Each Document
        df = pd.DataFrame(dominant_topics, columns=["Document_Id", "Dominant_Topic"])
        dominant_topic_in_each_doc = df.groupby("Dominant_Topic").size()
        df_dominant_topic_in_each_doc = dominant_topic_in_each_doc.to_frame(
            name="count"
        ).reset_index()

        # Total Topic Distribution by actual weight
        topic_weightage_by_doc = pd.DataFrame([dict(t) for t in topic_percentages])
        df_topic_weightage_by_doc = (
            topic_weightage_by_doc.sum().to_frame(name="count").reset_index()
        )

        # Top 3 Keywords for each Topic
        topic_top3words = [
            (i, topic)
            for i, topics in lda_model.show_topics(formatted=False)
            for j, (topic, wt) in enumerate(topics)
            if j < 3
        ]

        df_top3words_stacked = pd.DataFrame(
            topic_top3words, columns=["topic_id", "words"]
        )
        df_top3words = df_top3words_stacked.groupby("topic_id").agg(", \n".join)
        df_top3words.reset_index(level=0, inplace=True)

        # Plot
        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 4), dpi=120, sharey=True)

        # Topic Distribution by Dominant Topics
        ax1.bar(
            x="Dominant_Topic",
            height="count",
            data=df_dominant_topic_in_each_doc,
            width=0.5,
            color="firebrick",
        )
        ax1.set_xticks(
            range(df_dominant_topic_in_each_doc.Dominant_Topic.unique().__len__())
        )
        tick_formatter = FuncFormatter(
            lambda x, pos: "Topic "
            + str(x)
            + "\n"
            + df_top3words.loc[df_top3words.topic_id == x, "words"].values[0]  # type: ignore
        )
        ax1.xaxis.set_major_formatter(tick_formatter)
        ax1.set_title("Number of Documents by Dominant Topic", fontdict=dict(size=10))
        ax1.set_ylabel("Number of Documents")
        ax1.set_ylim(0, 1000)

        # Topic Distribution by Topic Weights
        ax2.bar(
            x="index",
            height="count",
            data=df_topic_weightage_by_doc,
            width=0.5,
            color="steelblue",
        )
        ax2.set_xticks(range(df_topic_weightage_by_doc.index.unique().__len__()))
        ax2.xaxis.set_major_formatter(tick_formatter)
        ax2.set_title("Number of Documents by Topic Weightage", fontdict=dict(size=10))

        plt.show(block=False)

        # save
        if folder_path:
            plt.savefig(folder_path)
            plt.close()

    def update_annot(self, ind):
        if self.annot is None or self.sc is None or self.c is None:
            raise RuntimeError("cluster_chart must be called before update_annot.")
        indices_array = np.atleast_1d(ind.get("ind", []))
        if indices_array.size == 0:
            return
        indices = indices_array.astype(int)
        idx = int(indices[0])
        offsets = np.asarray(self.sc.get_offsets())
        pos = offsets[idx]
        annot = self.annot
        annot.xy = (float(pos[0]), float(pos[1]))
        text = "{}, {}".format(
            " ".join(list(map(str, indices))),
            " ".join([self.names[n] for n in indices]),
        )
        annot.set_text(text)
        cmap = plt.get_cmap("RdYlGn")
        norm = mcolors.Normalize(1, 4)
        bbox = annot.get_bbox_patch()
        if bbox is not None:
            try:
                color_value = float(self.c[idx])
            except (TypeError, ValueError):
                color_value = 1.0
            bbox.set_facecolor(cmap(norm(color_value)))
            bbox.set_alpha(0.4)

    def hover(self, event):
        if self.annot is None or self.sc is None or self.fig is None or self.ax is None:
            return
        vis = self.annot.get_visible()
        if event.inaxes == self.ax:
            cont, ind = self.sc.contains(event)
            if cont:
                self.update_annot(ind)
                self.annot.set_visible(True)
                self.fig.canvas.draw_idle()
            elif vis:
                self.annot.set_visible(False)
                self.fig.canvas.draw_idle()

    # https://stackoverflow.com/questions/7908636/how-to-add-hovering-annotations-to-a-plot
    def cluster_chart(self, data, folder_path=None):
        """Interactive scatter plot of text-cluster predictions; hovering reveals document titles."""
        # Scatter plot for Text Cluster Prediction
        self.fig, self.ax = plt.subplots(figsize=(6, 6))
        self.names = list(map(str, data["title"]))
        self.sc = plt.scatter(
            data["x"],
            data["y"],
            c=data["colour"],
            s=36,
            edgecolors="black",
            linewidths=0.75,
        )
        self.c = np.asarray(data["colour"])
        self.annot = self.ax.annotate(
            "",
            xy=(0, 0),
            xytext=(20, 20),
            textcoords="offset points",
            bbox=dict(boxstyle="round", fc="w"),
            arrowprops=dict(arrowstyle="->"),
        )
        self.annot.set_visible(False)
        plt.title("Text Cluster Prediction")
        plt.axis("off")  # Optional: Remove axes for a cleaner look
        plt.colorbar(self.sc, label="Colour")  # Add colorbar if needed
        self.fig.canvas.mpl_connect("motion_notify_event", self.hover)
        plt.show(block=False)
        # save
        if folder_path:
            # annotate with data['title']
            for i, txt in enumerate(data["title"]):
                plt.annotate(
                    txt,
                    (data["x"][i], data["y"][i]),
                    fontsize=8,
                    ha="right",
                    va="bottom",
                )
            plt.savefig(folder_path)
            plt.close()

    def get_lda_viz(
        self,
        lda_model,
        corpus_bow,
        dictionary,
        folder_path: str | None = None,
        mds: str = "tsne",
        lambda_val: float = 0.6,
        show: bool = True,
    ) -> str | None:
        """
        Generate an interactive LDA visualization using pyLDAvis.

        Args:
            lda_model: The trained LDA model
            corpus_bow: Bag of words corpus
            dictionary: Gensim dictionary
            folder_path: Path to save the HTML visualization
            mds: Dimension reduction method ('tsne', 'mmds', or 'pcoa')
            lambda_val: Lambda parameter for relevance metric (default: 0.6).
                       Mettler et al. (2025) performed several experiments to identify
                       the optimal value of λ, which turned out to be 0.6.
            show: Whether to display the visualization

        Returns:
            HTML string of the visualization if successful, None otherwise

        Raises:
            ImportError: If pyLDAvis is not installed
            ValueError: If required inputs are missing
        """
        if not PYLDAVIS_AVAILABLE:
            raise ImportError(
                "pyLDAvis is not installed. Install it with: pip install pyLDAvis"
            )

        if lda_model is None:
            raise ValueError("LDA model is required")
        if corpus_bow is None:
            raise ValueError("Corpus bag of words is required")
        if dictionary is None:
            raise ValueError("Dictionary is required")

        try:
            # Prepare the visualization data
            vis_data = gensimvis.prepare(
                lda_model,
                corpus_bow,
                dictionary,
                mds=mds,
                R=30,
                lambda_step=0.01,
                plot_opts={"xlab": "PC1", "ylab": "PC2"},
            )

            # Save to HTML file if path provided
            if folder_path:
                output_path = Path(folder_path)
                if output_path.parent:
                    output_path.parent.mkdir(parents=True, exist_ok=True)
                pyLDAvis.save_html(vis_data, str(output_path))
                logger.info(f"LDA visualization saved to {output_path}")

            # Return HTML string for embedding or further use
            html_string = pyLDAvis.prepared_data_to_html(vis_data)
            return html_string

        except Exception as e:
            logger.error(f"Error generating LDA visualization: {e}")
            raise

    def draw_tdabm(
        self,
        corpus: Corpus | None = None,
        folder_path: str | None = None,
        show: bool = True,
    ) -> Figure:
        """
        Draw TDABM (Topological Data Analysis Ball Mapper) visualization.

        Creates a 2D graph showing landmark points as circles:
        - Circle size is proportional to the count of points in the ball
        - Circle color represents mean y value (red for low, green for high)
        - Lines connect landmark points with non-empty intersections

        Based on the algorithm by Rudkin and Dlotko (2024).

        Args:
            corpus: Corpus with 'tdabm' metadata. If None, uses self.corpus
            folder_path: Path to save the figure. If None, uses self.folder_path
            show: Whether to display the plot

        Returns:
            Matplotlib Figure object
        """
        if corpus is None:
            corpus = self.corpus

        if corpus is None:
            raise ValueError("No corpus provided")

        if "tdabm" not in corpus.metadata:
            raise ValueError(
                "Corpus metadata does not contain 'tdabm' data. Run TDABM analysis first."
            )

        tdabm_data = corpus.metadata["tdabm"]
        landmarks = tdabm_data["landmarks"]

        if not landmarks:
            raise ValueError("No landmarks found in TDABM data")

        # Create figure
        fig, ax = plt.subplots(figsize=(12, 10))

        # Collect all landmark locations
        locations = [landmark["location"] for landmark in landmarks]
        counts = [landmark["count"] for landmark in landmarks]
        mean_ys = [landmark["mean_y"] for landmark in landmarks]
        landmark_ids = [landmark["id"] for landmark in landmarks]

        # Perform PCA to reduce to 2 components (PC1, PC2)
        from sklearn.decomposition import PCA

        locations_array = np.array(locations)
        if locations_array.shape[1] < 2:
            # If only 1D, pad with zeros
            locations_array = np.pad(locations_array, ((0, 0), (0, 1)), mode="constant")
        pca = PCA(n_components=2)
        positions = pca.fit_transform(locations_array)

        # Normalize mean_y for color mapping (red=0, purple=max)
        min_y = min(mean_ys)
        max_y = max(mean_ys)

        if max_y - min_y > 0:
            normalized_ys = [(y - min_y) / (max_y - min_y) for y in mean_ys]
        else:
            normalized_ys = [0.5] * len(mean_ys)

        # Create color map: red (0) to green (1)
        colors = []
        for norm_y in normalized_ys:
            # Interpolate from red (1,0,0) to green (0,1,0)
            r = 1.0 - norm_y
            g = norm_y
            b = 0.0
            colors.append((r, g, b))

        # Draw connections first (so they appear behind circles)
        landmark_dict = {lm["id"]: idx for idx, lm in enumerate(landmarks)}

        for i, landmark in enumerate(landmarks):
            for connected_id in landmark["connections"]:
                if connected_id in landmark_dict:
                    j = landmark_dict[connected_id]
                    # Only draw each connection once (avoid duplicates)
                    if i < j:
                        ax.plot(
                            [positions[i, 0], positions[j, 0]],
                            [positions[i, 1], positions[j, 1]],
                            "k-",
                            alpha=0.3,
                            linewidth=1,
                            zorder=1,
                        )

        # Normalize counts for circle sizes (scale for visibility)
        max_count = max(counts)
        min_count = min(counts)

        if max_count > min_count:
            # Scale sizes between 100 and 2000
            sizes = [
                100 + 1900 * (c - min_count) / (max_count - min_count) for c in counts
            ]
        else:
            sizes = [500] * len(counts)

        # Draw circles for landmarks
        scatter = ax.scatter(
            positions[:, 0],
            positions[:, 1],
            s=sizes,
            c=colors,
            alpha=0.6,
            edgecolors="black",
            linewidths=1.5,
            zorder=2,
        )

        # Add count and mean_y as label inside each circle
        for i, (pos, count, mean_y) in enumerate(zip(positions, counts, mean_ys)):
            ax.annotate(
                f"{count}\n{mean_y:.2f}",
                xy=pos,
                xytext=(0, 0),
                textcoords="offset points",
                ha="center",
                va="center",
                fontsize=8,
                fontweight="bold",
                zorder=3,
            )

        # Set labels and title
        x_vars = tdabm_data.get("x_variables", [])
        y_var = tdabm_data.get("y_variable", "y")

        # Axis labels reflect PCA components
        ax.set_xlabel("PC1", fontsize=12)
        ax.set_ylabel("PC2", fontsize=12)

        ax.set_title(
            f"TDABM Visualization\n"
            f'Y variable: {y_var}, Radius: {tdabm_data.get("radius", 0.3)}\n'
            f"Landmarks: {len(landmarks)}",
            fontsize=14,
            fontweight="bold",
        )

        # Add colorbar for mean_y (red to green)
        sm = plt.cm.ScalarMappable(
            cmap=mcolors.LinearSegmentedColormap.from_list(
                "red_green", ["red", "green"]
            ),
            norm=mcolors.Normalize(vmin=min_y, vmax=max_y),
        )
        sm.set_array([])
        cbar = plt.colorbar(sm, ax=ax)
        cbar.set_label(f"Mean {y_var}", fontsize=12)

        # Add legend for circle sizes
        # Create dummy scatter plots for legend
        legend_counts = [min_count, (min_count + max_count) / 2, max_count]
        legend_sizes = []
        for c in legend_counts:
            if max_count > min_count:
                size = 100 + 1900 * (c - min_count) / (max_count - min_count)
            else:
                size = 500
            legend_sizes.append(size)

        legend_elements = []
        for size, count in zip(legend_sizes, legend_counts):
            legend_elements.append(
                plt.scatter(
                    [],
                    [],
                    s=size,
                    c="gray",
                    alpha=0.6,
                    edgecolors="black",
                    linewidths=1.5,
                    label=f"{int(count)} points",
                )
            )

        ax.legend(
            handles=legend_elements,
            title="Ball Size",
            loc="upper right",
            framealpha=0.9,
        )

        ax.grid(True, alpha=0.3)
        ax.set_aspect("equal", adjustable="box")

        plt.tight_layout()

        return self._finalize_plot(fig, folder_path, show)

    def draw_graph(
        self,
        corpus: Corpus | None = None,
        folder_path: str | None = None,
        show: bool = True,
        layout: str = "spring",
    ) -> Figure:
        """
        Draw graph visualization from corpus metadata.

        Creates a visualization of the graph structure showing documents,
        keywords, clusters, and metadata nodes along with their relationships.

        Args:
            corpus: Corpus with 'graph' metadata. If None, uses self.corpus
            folder_path: Path to save the figure. If None, uses self.folder_path
            show: Whether to display the plot
            layout: Graph layout algorithm ('spring', 'circular', 'kamada_kawai', 'spectral')

        Returns:
            Matplotlib Figure object

        Raises:
            ValueError: If corpus or graph metadata is missing
        """
        if corpus is None:
            corpus = self.corpus

        if corpus is None:
            raise ValueError("No corpus provided")

        if "graph" not in corpus.metadata:
            raise ValueError(
                "Corpus metadata does not contain 'graph' data. Run graph generation first."
            )

        graph_data = corpus.metadata["graph"]
        nodes = graph_data["nodes"]
        edges = graph_data["edges"]

        if not nodes:
            raise ValueError("No nodes found in graph data")

        # Create NetworkX graph
        G = nx.Graph()

        # Add nodes with their labels (store as maps keyed by node id)
        node_labels: dict[str, str] = {}
        node_color_map_by_id: dict[str, str] = {}
        node_size_map_by_id: dict[str, float] = {}

        # Color mapping for different node types
        color_map = {
            "document": "#FF6B6B",  # Red
            "keyword": "#4ECDC4",  # Teal
            "cluster": "#95E1D3",  # Light green
            "metadata": "#FFD93D",  # Yellow
        }

        for node in nodes:
            node_id = str(node.get("id"))
            label = node.get("label", "metadata")
            properties = node.get("properties", {})

            G.add_node(node_id, label=label, **properties)

            # Set node label (use name property if available)
            if "name" in properties and properties["name"]:
                node_labels[node_id] = str(properties["name"])
            else:
                # For keywords, remove the "keyword:" prefix
                if node_id.startswith("keyword:"):
                    node_labels[node_id] = node_id.replace("keyword:", "")
                elif node_id.startswith("cluster:"):
                    node_labels[node_id] = f"C{node_id.replace('cluster:', '')}"
                elif node_id.startswith("metadata:"):
                    node_labels[node_id] = "M"
                else:
                    node_labels[node_id] = node_id

            # Set node color based on type
            node_color_map_by_id[node_id] = color_map.get(label, "#CCCCCC")

            # Set node size based on type (documents larger)
            if label == "document":
                node_size_map_by_id[node_id] = 800.0
            elif label == "keyword":
                node_size_map_by_id[node_id] = 500.0
            elif label == "cluster":
                node_size_map_by_id[node_id] = 600.0
            else:
                node_size_map_by_id[node_id] = 400.0

        # Add edges
        for edge in edges:
            source = str(edge.get("source"))
            target = str(edge.get("target"))
            # If edge introduces unknown nodes, add with default properties
            if source not in G:
                G.add_node(source, label="metadata")
                node_labels[source] = (
                    source if not source.startswith("metadata:") else "M"
                )
                node_color_map_by_id[source] = color_map.get("metadata", "#CCCCCC")
                node_size_map_by_id[source] = 400.0
            if target not in G:
                G.add_node(target, label="metadata")
                node_labels[target] = (
                    target if not target.startswith("metadata:") else "M"
                )
                node_color_map_by_id[target] = color_map.get("metadata", "#CCCCCC")
                node_size_map_by_id[target] = 400.0
            G.add_edge(source, target)

        # Create figure
        fig, ax = plt.subplots(figsize=(16, 12))

        # Choose layout algorithm
        if layout == "spring":
            pos = nx.spring_layout(G, k=2, iterations=50, seed=42)
        elif layout == "circular":
            pos = nx.circular_layout(G)
        elif layout == "kamada_kawai":
            pos = nx.kamada_kawai_layout(G)
        elif layout == "spectral":
            pos = nx.spectral_layout(G)
        else:
            pos = nx.spring_layout(G, seed=42)

        # Draw edges first (so they appear behind nodes)
        nx.draw_networkx_edges(
            G,
            pos,
            ax=ax,
            edge_color="#CCCCCC",
            width=1.5,
            alpha=0.6,
        )

        # Build aligned arrays for node attributes
        nodelist = list(G.nodes())
        node_colors = [node_color_map_by_id.get(n, "#CCCCCC") for n in nodelist]
        node_sizes = np.asarray(
            [float(node_size_map_by_id.get(n, 400.0)) for n in nodelist]
        )
        # Draw nodes with explicit nodelist
        nx.draw_networkx_nodes(
            G,
            pos,
            nodelist=nodelist,
            ax=ax,
            node_color=node_colors,
            node_size=node_sizes,
            alpha=0.9,
            edgecolors="black",
            linewidths=1.5,
        )

        # Draw labels
        # Labels aligned to nodelist
        labels_ordered = {n: node_labels.get(n, str(n)) for n in nodelist}
        nx.draw_networkx_labels(
            G,
            pos,
            labels=labels_ordered,
            ax=ax,
            font_size=8,
            font_weight="bold",
            font_color="black",
        )

        # Add title and legend
        # Compute stats if not provided
        num_nodes = int(graph_data.get("num_nodes", len(G.nodes())))
        num_edges = int(graph_data.get("num_edges", len(G.edges())))
        num_documents = int(
            graph_data.get(
                "num_documents",
                sum(1 for _, d in G.nodes(data=True) if d.get("label") == "document"),
            )
        )
        ax.set_title(
            f"Graph Visualization\n"
            f"Nodes: {num_nodes}, "
            f"Edges: {num_edges}, "
            f"Documents: {num_documents}",
            fontsize=14,
            fontweight="bold",
            pad=20,
        )

        # Create legend
        from matplotlib.patches import Patch

        legend_elements = []
        for node_type, color in color_map.items():
            legend_elements.append(
                Patch(facecolor=color, edgecolor="black", label=node_type.capitalize())
            )

        ax.legend(
            handles=legend_elements,
            loc="upper left",
            framealpha=0.9,
            title="Node Types",
        )

        ax.axis("off")
        plt.tight_layout()

        return self._finalize_plot(fig, folder_path, show)
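
A minimal usage sketch, assuming the import path follows the file layout (src/crisp_t/visualize.py). Passing a DataFrame explicitly avoids relying on corpus.visualization, so no Corpus is needed here:

import pandas as pd

from crisp_t.visualize import QRVisualize  # import path assumed from the file layout

viz = QRVisualize()  # no corpus required when a DataFrame is supplied
df = pd.DataFrame(
    {
        "term": ["care", "nurse", "ward", "shift"],
        "frequency": [42, 31, 18, 12],
    }
)
fig, ax = viz.plot_top_terms(df=df, top_n=4, folder_path="top_terms.png", show=False)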

draw_graph(corpus=None, folder_path=None, show=True, layout='spring')

Draw graph visualization from corpus metadata.

Creates a visualization of the graph structure showing documents, keywords, clusters, and metadata nodes along with their relationships.

Parameters:

Name Type Description Default
corpus Corpus | None

Corpus with 'graph' metadata. If None, uses self.corpus

None
folder_path str | None

Path to save the figure. If None, uses self.folder_path

None
show bool

Whether to display the plot

True
layout str

Graph layout algorithm ('spring', 'circular', 'kamada_kawai', 'spectral')

'spring'

Returns:

Type Description
Figure

Matplotlib Figure object

Raises:

Type Description
ValueError

If corpus or graph metadata is missing

Source code in src/crisp_t/visualize.py
def draw_graph(
    self,
    corpus: Corpus | None = None,
    folder_path: str | None = None,
    show: bool = True,
    layout: str = "spring",
) -> Figure:
    """
    Draw graph visualization from corpus metadata.

    Creates a visualization of the graph structure showing documents,
    keywords, clusters, and metadata nodes along with their relationships.

    Args:
        corpus: Corpus with 'graph' metadata. If None, uses self.corpus
        folder_path: Path to save the figure. If None, uses self.folder_path
        show: Whether to display the plot
        layout: Graph layout algorithm ('spring', 'circular', 'kamada_kawai', 'spectral')

    Returns:
        Matplotlib Figure object

    Raises:
        ValueError: If corpus or graph metadata is missing
    """
    if corpus is None:
        corpus = self.corpus

    if corpus is None:
        raise ValueError("No corpus provided")

    if "graph" not in corpus.metadata:
        raise ValueError(
            "Corpus metadata does not contain 'graph' data. Run graph generation first."
        )

    graph_data = corpus.metadata["graph"]
    nodes = graph_data["nodes"]
    edges = graph_data["edges"]

    if not nodes:
        raise ValueError("No nodes found in graph data")

    # Create NetworkX graph
    G = nx.Graph()

    # Add nodes with their labels (store as maps keyed by node id)
    node_labels: dict[str, str] = {}
    node_color_map_by_id: dict[str, str] = {}
    node_size_map_by_id: dict[str, float] = {}

    # Color mapping for different node types
    color_map = {
        "document": "#FF6B6B",  # Red
        "keyword": "#4ECDC4",  # Teal
        "cluster": "#95E1D3",  # Light green
        "metadata": "#FFD93D",  # Yellow
    }

    for node in nodes:
        node_id = str(node.get("id"))
        label = node.get("label", "metadata")
        properties = node.get("properties", {})

        G.add_node(node_id, label=label, **properties)

        # Set node label (use name property if available)
        if "name" in properties and properties["name"]:
            node_labels[node_id] = str(properties["name"])
        else:
            # For keywords, remove the "keyword:" prefix
            if node_id.startswith("keyword:"):
                node_labels[node_id] = node_id.replace("keyword:", "")
            elif node_id.startswith("cluster:"):
                node_labels[node_id] = f"C{node_id.replace('cluster:', '')}"
            elif node_id.startswith("metadata:"):
                node_labels[node_id] = "M"
            else:
                node_labels[node_id] = node_id

        # Set node color based on type
        node_color_map_by_id[node_id] = color_map.get(label, "#CCCCCC")

        # Set node size based on type (documents larger)
        if label == "document":
            node_size_map_by_id[node_id] = 800.0
        elif label == "keyword":
            node_size_map_by_id[node_id] = 500.0
        elif label == "cluster":
            node_size_map_by_id[node_id] = 600.0
        else:
            node_size_map_by_id[node_id] = 400.0

    # Add edges
    for edge in edges:
        source = str(edge.get("source"))
        target = str(edge.get("target"))
        # If edge introduces unknown nodes, add with default properties
        if source not in G:
            G.add_node(source, label="metadata")
            node_labels[source] = (
                source if not source.startswith("metadata:") else "M"
            )
            node_color_map_by_id[source] = color_map.get("metadata", "#CCCCCC")
            node_size_map_by_id[source] = 400.0
        if target not in G:
            G.add_node(target, label="metadata")
            node_labels[target] = (
                target if not target.startswith("metadata:") else "M"
            )
            node_color_map_by_id[target] = color_map.get("metadata", "#CCCCCC")
            node_size_map_by_id[target] = 400.0
        G.add_edge(source, target)

    # Create figure
    fig, ax = plt.subplots(figsize=(16, 12))

    # Choose layout algorithm
    if layout == "spring":
        pos = nx.spring_layout(G, k=2, iterations=50, seed=42)
    elif layout == "circular":
        pos = nx.circular_layout(G)
    elif layout == "kamada_kawai":
        pos = nx.kamada_kawai_layout(G)
    elif layout == "spectral":
        pos = nx.spectral_layout(G)
    else:
        pos = nx.spring_layout(G, seed=42)

    # Draw edges first (so they appear behind nodes)
    nx.draw_networkx_edges(
        G,
        pos,
        ax=ax,
        edge_color="#CCCCCC",
        width=1.5,
        alpha=0.6,
    )

    # Build aligned arrays for node attributes
    nodelist = list(G.nodes())
    node_colors = [node_color_map_by_id.get(n, "#CCCCCC") for n in nodelist]
    node_sizes = np.asarray(
        [float(node_size_map_by_id.get(n, 400.0)) for n in nodelist]
    )
    # Draw nodes with explicit nodelist
    nx.draw_networkx_nodes(
        G,
        pos,
        nodelist=nodelist,
        ax=ax,
        node_color=node_colors,
        node_size=node_sizes,
        alpha=0.9,
        edgecolors="black",
        linewidths=1.5,
    )

    # Draw labels
    # Labels aligned to nodelist
    labels_ordered = {n: node_labels.get(n, str(n)) for n in nodelist}
    nx.draw_networkx_labels(
        G,
        pos,
        labels=labels_ordered,
        ax=ax,
        font_size=8,
        font_weight="bold",
        font_color="black",
    )

    # Add title and legend
    # Compute stats if not provided
    num_nodes = int(graph_data.get("num_nodes", len(G.nodes())))
    num_edges = int(graph_data.get("num_edges", len(G.edges())))
    num_documents = int(
        graph_data.get(
            "num_documents",
            sum(1 for _, d in G.nodes(data=True) if d.get("label") == "document"),
        )
    )
    ax.set_title(
        f"Graph Visualization\n"
        f"Nodes: {num_nodes}, "
        f"Edges: {num_edges}, "
        f"Documents: {num_documents}",
        fontsize=14,
        fontweight="bold",
        pad=20,
    )

    # Create legend
    from matplotlib.patches import Patch

    legend_elements = []
    for node_type, color in color_map.items():
        legend_elements.append(
            Patch(facecolor=color, edgecolor="black", label=node_type.capitalize())
        )

    ax.legend(
        handles=legend_elements,
        loc="upper left",
        framealpha=0.9,
        title="Node Types",
    )

    ax.axis("off")
    plt.tight_layout()

    return self._finalize_plot(fig, folder_path, show)
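
For orientation, here is a minimal sketch of the 'graph' entry in corpus.metadata that the method above consumes. The shape is inferred from the reads in the code (node 'id', 'label', and 'properties'; edge 'source' and 'target'; optional num_* summary keys); the concrete ids and values are illustrative only, not taken from the library.

graph_metadata_example = {
    "nodes": [
        {"id": "doc1", "label": "document", "properties": {"name": "Interview 1"}},
        {"id": "keyword:care", "label": "keyword", "properties": {}},
        {"id": "cluster:0", "label": "cluster", "properties": {}},
    ],
    "edges": [
        {"source": "doc1", "target": "keyword:care"},
        {"source": "doc1", "target": "cluster:0"},
    ],
    # Optional summary keys; if absent, the method recomputes them from the graph.
    "num_nodes": 3,
    "num_edges": 2,
    "num_documents": 1,
}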

draw_tdabm(corpus=None, folder_path=None, show=True)

Draw TDABM (Topological Data Analysis Ball Mapper) visualization.

Creates a 2D graph showing landmark points as circles:
- Circle size is proportional to the count of points in the ball
- Circle color represents the mean y value (red for low, green for high)
- Lines connect landmark points whose balls have a non-empty intersection

Based on the algorithm by Rudkin and Dlotko (2024).

Parameters:

    corpus (Corpus | None, default None): Corpus with 'tdabm' metadata. If None, uses self.corpus.
    folder_path (str | None, default None): Path to save the figure. If None, uses self.folder_path.
    show (bool, default True): Whether to display the plot.

Returns:

    Figure: Matplotlib Figure object.

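A brief usage sketch, assuming `viz` is an instance of the visualizer class defined in src/crisp_t/visualize.py (its constructor is not shown on this page) and `corpus` already carries a 'tdabm' metadata entry; the folder path is illustrative.

# `viz` and the output path are placeholders; the call signature matches the docs above.
fig = viz.draw_tdabm(corpus=corpus, folder_path="output/figures", show=False)
fig.savefig("output/figures/tdabm.png", dpi=300)  # plain Matplotlib save; _finalize_plot may also write a copy
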
Source code in src/crisp_t/visualize.py (lines 802-1001)
def draw_tdabm(
    self,
    corpus: Corpus | None = None,
    folder_path: str | None = None,
    show: bool = True,
) -> Figure:
    """
    Draw TDABM (Topological Data Analysis Ball Mapper) visualization.

    Creates a 2D graph showing landmark points as circles:
    - Circle size is proportional to the count of points in the ball
    - Circle color represents mean y value (red for low, purple for high)
    - Lines connect landmark points with non-empty intersections

    Based on the algorithm by Rudkin and Dlotko (2024).

    Args:
        corpus: Corpus with 'tdabm' metadata. If None, uses self.corpus
        folder_path: Path to save the figure. If None, uses self.folder_path
        show: Whether to display the plot

    Returns:
        Matplotlib Figure object
    """
    if corpus is None:
        corpus = self.corpus

    if corpus is None:
        raise ValueError("No corpus provided")

    if "tdabm" not in corpus.metadata:
        raise ValueError(
            "Corpus metadata does not contain 'tdabm' data. Run TDABM analysis first."
        )

    tdabm_data = corpus.metadata["tdabm"]
    landmarks = tdabm_data["landmarks"]

    if not landmarks:
        raise ValueError("No landmarks found in TDABM data")

    # Create figure
    fig, ax = plt.subplots(figsize=(12, 10))

    # Collect all landmark locations
    locations = [landmark["location"] for landmark in landmarks]
    counts = [landmark["count"] for landmark in landmarks]
    mean_ys = [landmark["mean_y"] for landmark in landmarks]
    landmark_ids = [landmark["id"] for landmark in landmarks]

    # Perform PCA to reduce to 2 components (PC1, PC2)
    from sklearn.decomposition import PCA

    locations_array = np.array(locations)
    if locations_array.shape[1] < 2:
        # If only 1D, pad with zeros
        locations_array = np.pad(locations_array, ((0, 0), (0, 1)), mode="constant")
    pca = PCA(n_components=2)
    positions = pca.fit_transform(locations_array)

    # Normalize mean_y for color mapping (red=0, purple=max)
    min_y = min(mean_ys)
    max_y = max(mean_ys)

    if max_y - min_y > 0:
        normalized_ys = [(y - min_y) / (max_y - min_y) for y in mean_ys]
    else:
        normalized_ys = [0.5] * len(mean_ys)

    # Create color map: red (0) to green (1)
    colors = []
    for norm_y in normalized_ys:
        # Interpolate from red (1,0,0) to green (0,1,0)
        r = 1.0 - norm_y
        g = norm_y
        b = 0.0
        colors.append((r, g, b))

    # Draw connections first (so they appear behind circles)
    landmark_dict = {lm["id"]: idx for idx, lm in enumerate(landmarks)}

    for i, landmark in enumerate(landmarks):
        for connected_id in landmark["connections"]:
            if connected_id in landmark_dict:
                j = landmark_dict[connected_id]
                # Only draw each connection once (avoid duplicates)
                if i < j:
                    ax.plot(
                        [positions[i, 0], positions[j, 0]],
                        [positions[i, 1], positions[j, 1]],
                        "k-",
                        alpha=0.3,
                        linewidth=1,
                        zorder=1,
                    )

    # Normalize counts for circle sizes (scale for visibility)
    max_count = max(counts)
    min_count = min(counts)

    if max_count > min_count:
        # Scale sizes between 100 and 2000
        sizes = [
            100 + 1900 * (c - min_count) / (max_count - min_count) for c in counts
        ]
    else:
        sizes = [500] * len(counts)

    # Draw circles for landmarks
    scatter = ax.scatter(
        positions[:, 0],
        positions[:, 1],
        s=sizes,
        c=colors,
        alpha=0.6,
        edgecolors="black",
        linewidths=1.5,
        zorder=2,
    )

    # Add count and mean_y as label inside each circle
    for i, (pos, count, mean_y) in enumerate(zip(positions, counts, mean_ys)):
        ax.annotate(
            f"{count}\n{mean_y:.2f}",
            xy=pos,
            xytext=(0, 0),
            textcoords="offset points",
            ha="center",
            va="center",
            fontsize=8,
            fontweight="bold",
            zorder=3,
        )

    # Set labels and title
    x_vars = tdabm_data.get("x_variables", [])
    y_var = tdabm_data.get("y_variable", "y")

    # Axis labels reflect PCA components
    ax.set_xlabel("PC1", fontsize=12)
    ax.set_ylabel("PC2", fontsize=12)

    ax.set_title(
        f"TDABM Visualization\n"
        f'Y variable: {y_var}, Radius: {tdabm_data.get("radius", 0.3)}\n'
        f"Landmarks: {len(landmarks)}",
        fontsize=14,
        fontweight="bold",
    )

    # Add colorbar for mean_y (red to green)
    sm = plt.cm.ScalarMappable(
        cmap=mcolors.LinearSegmentedColormap.from_list(
            "red_green", ["red", "green"]
        ),
        norm=mcolors.Normalize(vmin=min_y, vmax=max_y),
    )
    sm.set_array([])
    cbar = plt.colorbar(sm, ax=ax)
    cbar.set_label(f"Mean {y_var}", fontsize=12)

    # Add legend for circle sizes
    # Create dummy scatter plots for legend
    legend_counts = [min_count, (min_count + max_count) / 2, max_count]
    legend_sizes = []
    for c in legend_counts:
        if max_count > min_count:
            size = 100 + 1900 * (c - min_count) / (max_count - min_count)
        else:
            size = 500
        legend_sizes.append(size)

    legend_elements = []
    for size, count in zip(legend_sizes, legend_counts):
        legend_elements.append(
            plt.scatter(
                [],
                [],
                s=size,
                c="gray",
                alpha=0.6,
                edgecolors="black",
                linewidths=1.5,
                label=f"{int(count)} points",
            )
        )

    ax.legend(
        handles=legend_elements,
        title="Ball Size",
        loc="upper right",
        framealpha=0.9,
    )

    ax.grid(True, alpha=0.3)
    ax.set_aspect("equal", adjustable="box")

    plt.tight_layout()

    return self._finalize_plot(fig, folder_path, show)
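
For reference, a minimal sketch of the 'tdabm' metadata structure this method reads, inferred from the accesses above (landmark 'id', 'location', 'count', 'mean_y', 'connections'; top-level 'x_variables', 'y_variable', 'radius'); the values and variable names are illustrative only.

tdabm_metadata_example = {
    "landmarks": [
        {
            "id": 0,
            "location": [0.12, 0.85, 0.33],  # coordinates in the x-variable space
            "count": 14,                     # number of points covered by this ball
            "mean_y": 0.42,                  # mean of the y variable within the ball
            "connections": [1],              # ids of landmarks with overlapping balls
        },
        {
            "id": 1,
            "location": [0.30, 0.70, 0.41],
            "count": 9,
            "mean_y": 0.61,
            "connections": [0],
        },
    ],
    "x_variables": ["sentiment", "word_count", "topic_0"],  # illustrative names
    "y_variable": "outcome",
    "radius": 0.3,
}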

get_lda_viz(lda_model, corpus_bow, dictionary, folder_path=None, mds='tsne', lambda_val=0.6, show=True)

Generate an interactive LDA visualization using pyLDAvis.

Parameters:

    lda_model (required): The trained LDA model.
    corpus_bow (required): Bag-of-words corpus.
    dictionary (required): Gensim dictionary.
    folder_path (str | None, default None): Path to save the HTML visualization.
    mds (str, default 'tsne'): Dimension reduction method ('tsne', 'mmds', or 'pcoa').
    lambda_val (float, default 0.6): Lambda parameter for the relevance metric; Mettler et al. (2025) found λ = 0.6 to be optimal across several experiments.
    show (bool, default True): Whether to display the visualization.

Returns:

    str | None: HTML string of the visualization if successful, None otherwise.

Raises:

    ImportError: If pyLDAvis is not installed.
    ValueError: If required inputs are missing.

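A hedged usage sketch follows. The gensim preprocessing is standard library usage rather than code taken from CRISP-T, `viz` again stands for an instance of the visualizer class in src/crisp_t/visualize.py, and the output path is illustrative; pyLDAvis must be installed for the call to succeed.

from gensim import corpora
from gensim.models import LdaModel

# Toy documents, tokenized; real corpora would come from the CRISP-T pipeline.
tokenized_docs = [doc.split() for doc in ["patient care access", "care quality cost"]]
dictionary = corpora.Dictionary(tokenized_docs)
corpus_bow = [dictionary.doc2bow(tokens) for tokens in tokenized_docs]
lda_model = LdaModel(corpus=corpus_bow, id2word=dictionary, num_topics=2, passes=5, random_state=42)

html = viz.get_lda_viz(
    lda_model,
    corpus_bow,
    dictionary,
    folder_path="output/lda_vis.html",  # saved as an HTML file
    mds="pcoa",
)
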
Source code in src/crisp_t/visualize.py (lines 731-800)
def get_lda_viz(
    self,
    lda_model,
    corpus_bow,
    dictionary,
    folder_path: str | None = None,
    mds: str = "tsne",
    lambda_val: float = 0.6,
    show: bool = True,
) -> str | None:
    """
    Generate an interactive LDA visualization using pyLDAvis.

    Args:
        lda_model: The trained LDA model
        corpus_bow: Bag of words corpus
        dictionary: Gensim dictionary
        folder_path: Path to save the HTML visualization
        mds: Dimension reduction method ('tsne', 'mmds', or 'pcoa')
        lambda_val: Lambda parameter for relevance metric (default: 0.6).
                   Mettler et al. (2025) performed several experiments to identify
                   the optimal value of λ, which turned out to be 0.6.
        show: Whether to display the visualization

    Returns:
        HTML string of the visualization if successful, None otherwise

    Raises:
        ImportError: If pyLDAvis is not installed
        ValueError: If required inputs are missing
    """
    if not PYLDAVIS_AVAILABLE:
        raise ImportError(
            "pyLDAvis is not installed. Install it with: pip install pyLDAvis"
        )

    if lda_model is None:
        raise ValueError("LDA model is required")
    if corpus_bow is None:
        raise ValueError("Corpus bag of words is required")
    if dictionary is None:
        raise ValueError("Dictionary is required")

    try:
        # Prepare the visualization data
        vis_data = gensimvis.prepare(
            lda_model,
            corpus_bow,
            dictionary,
            mds=mds,
            R=30,
            lambda_step=0.01,
            plot_opts={"xlab": "PC1", "ylab": "PC2"},
        )

        # Save to HTML file if path provided
        if folder_path:
            output_path = Path(folder_path)
            if output_path.parent:
                output_path.parent.mkdir(parents=True, exist_ok=True)
            pyLDAvis.save_html(vis_data, str(output_path))
            logger.info(f"LDA visualization saved to {output_path}")

        # Return HTML string for embedding or further use
        html_string = pyLDAvis.prepared_data_to_html(vis_data)
        return html_string

    except Exception as e:
        logger.error(f"Error generating LDA visualization: {e}")
        raise