API Reference

Core Runtime

These modules control configuration parsing, persistence, and the top-level benchmark entrypoint.

config.py

Configuration management for YAML-based configs, similar to the tf-binding project. Handles both YAML file loading and command-line argument parsing.

class scTimeBench.config.Config

Bases: object

Config class for both YAML and CLI arguments.

class scTimeBench.config.CsvExportType(value)

Bases: Enum

Types of CSV export supported by the benchmark.

EMBEDDING = 'embedding'
GEX_PRED = 'gex_pred'
GRAPH_SIM = 'graph_sim'
class scTimeBench.config.CsvWriteMode(value)

Bases: Enum

Write modes for CSV export.

MERGE = 'merge'
SEPARATE = 'separate'
class scTimeBench.config.RunType(value)

Bases: Enum

Run modes for the benchmark entrypoint.

AUTO_TRAIN_TEST = 'auto_train_test'
EVAL_ONLY = 'eval_only'
PREPROCESS = 'preprocess'
TRAIN_ONLY = 'train_only'
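For illustration, the run modes above behave like a standard Python Enum. A minimal stand-in (not the actual scTimeBench.config class) shows how a run-type string from a YAML config or CLI argument maps to a member:

```python
from enum import Enum

# Minimal stand-in for scTimeBench.config.RunType, for illustration only;
# the real class lives in scTimeBench.config.
class RunType(Enum):
    AUTO_TRAIN_TEST = 'auto_train_test'
    EVAL_ONLY = 'eval_only'
    PREPROCESS = 'preprocess'
    TRAIN_ONLY = 'train_only'

# Look up a member by its value, e.g. when parsing a run_type string:
mode = RunType('train_only')
print(mode.name)  # TRAIN_ONLY
```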

database.py

Database manager using sqlite3.

This module provides a simple interface to interact with an SQLite database, including the setup of tables for storing:

  1. Paths to processed datasets.

  2. Paths to method checkpoints.

  3. Paths to method predictions.

  4. Metric results.

class scTimeBench.database.DatabaseManager(config: Config)

Bases: object

clear_tables()
close()
embedding_to_csv(output_csv_path, append=False)
get_dataset_id(method: MethodManager)
get_dataset_tag_from_id(dataset_id)
get_evals_per_method(method: MethodManager)
get_evals_per_metric(metric_name: str, metric_params: str)
get_method_output_path(method: MethodManager)
gex_pred_to_csv(output_csv_path, append=False)
graph_sim_to_csv(output_csv_path, append=False)
has_eval(method: MethodManager, metric_name: str, metric_params: str) → bool
has_metric(name: str, parameters: str) → bool
insert_dataset(dataset: BaseDataset)
insert_dataset_metric(dataset: BaseDataset, metric_name, metric_params, result)
insert_eval(method: MethodManager, metric_name: str, metric_params: str, result)
insert_method_output(method: MethodManager, output_path: str)
insert_metric(name: str, parameters: str)
print_all()
return_all()
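The has_eval / insert_eval pair suggests a simple caching pattern: skip a metric when a result for the same (method, metric, parameters) triple is already stored. A sketch with an in-memory SQLite database, using a hypothetical schema rather than the package's actual tables:

```python
import sqlite3

# Hypothetical sketch of the caching pattern behind has_eval()/insert_eval().
# Table and column names are illustrative, not the package's real schema.
conn = sqlite3.connect(':memory:')
conn.execute(
    "CREATE TABLE evals (method TEXT, metric TEXT, params TEXT, result REAL)"
)

def has_eval(method, metric, params):
    # True if a result for this (method, metric, params) triple exists
    row = conn.execute(
        "SELECT 1 FROM evals WHERE method=? AND metric=? AND params=?",
        (method, metric, params),
    ).fetchone()
    return row is not None

def insert_eval(method, metric, params, result):
    conn.execute("INSERT INTO evals VALUES (?, ?, ?, ?)",
                 (method, metric, params, result))

# Only compute and store the metric if it has not been evaluated yet.
if not has_eval('kNN', 'accuracy', '{}'):
    insert_eval('kNN', 'accuracy', '{}', 0.91)
print(has_eval('kNN', 'accuracy', '{}'))  # True
```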

main.py

Entrypoint for measuring trajectories in single-cell data, particularly involving gene regulatory networks and cell lineage information.

scTimeBench.main.iterate_csv_types(config: Config)
scTimeBench.main.main()

Main entrypoint for the scTimeBench (crispy-fishstick) package.

scTimeBench.main.plot(config: Config)
scTimeBench.main.print_available(config: Config)

Print available datasets and metrics.

scTimeBench.main.run_metrics(config: Config)

Run the specified metrics based on the provided configuration.

scTimeBench.main.to_csv(config: Config)
scTimeBench.main.view_evals_by_method(config: Config)

View evaluations grouped by method.

scTimeBench.main.view_evals_by_metric(config: Config)

View evaluations grouped by metric.

Dataset Infrastructure

These modules define dataset loading, preprocessing, shared constants, and utility helpers used throughout the benchmark.

Definitions to be shared by the benchmark and the method implementations.

class scTimeBench.shared.constants.ObservationColumns(value)

Bases: Enum

Standardized observation column names used across the benchmark.

CELL_TYPE = 'scTimeBench_cell_type'
TIMEPOINT = 'scTimeBench_timepoint'
class scTimeBench.shared.constants.RequiredOutputFiles(value)

Bases: Enum

Output files that each method is required to produce.

EMBEDDING = 'embedding.npy'
FROM_ZERO_TO_END_PRED_GEX = 'from_zero_to_end_predicted_gene_expression.h5ad'
NEXT_CELLTYPE = 'next_cell_type.parquet'
NEXT_TIMEPOINT_EMBEDDING = 'next_timepoint_embedding.npy'
NEXT_TIMEPOINT_GENE_EXPRESSION = 'next_timepoint_gene_expression.npy'
PRED_GRAPH = 'predicted_graph.npy'

Base preprocessor for datasets. Every metric will likely require different splits of the data, so this base class will define the necessary interface for dataset preprocessing.

class scTimeBench.shared.dataset.base.BaseDataset(dataset_dict, dataset_preprocessors: list[BaseDatasetPreprocessor], output_dir)

Bases: object

create_dataset_dir()

Create a directory for this dataset configuration under the given base path.

encode_dataset_dict()

Generate a string representation of the dataset configuration.

This can be used to cache processed datasets.

encode_preprocessors(i=None)

Generate a string representation of the applied dataset preprocessors and their parameters.

This can be used to cache processed datasets.

get_checkpoint_dir(i)

We define a checkpoint as the output of the ith preprocessor in the pipeline. This is used to save intermediate results that are expensive to compute (such as pseudotime estimation).

get_dataset_dir()

Get a unique directory name for this dataset configuration, which can be used for caching. This is based on the dataset name, the encoded dataset dictionary, and the encoded preprocessors.

It should be a hashable string that uniquely identifies the dataset configuration and applied preprocessors, so that we can cache processed datasets effectively.
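One way such a unique, cacheable directory name could be built is to hash a canonical JSON serialization of the configuration. This is a sketch of the idea; the real encode_dataset_dict / encode_preprocessors logic may differ:

```python
import hashlib
import json

# Illustrative encoding: hash a canonical (sorted-key) JSON dump so the
# same configuration always yields the same directory name.
def encode(obj):
    payload = json.dumps(obj, sort_keys=True).encode()
    return hashlib.sha256(payload).hexdigest()[:12]

# Hypothetical dataset configuration and preprocessor list.
dataset_dict = {'name': 'embryo', 'timepoints': [0, 1, 2]}
preprocessors = [{'type': 'hvg', 'n_top_genes': 2000}]

dirname = f"{dataset_dict['name']}_{encode(dataset_dict)}_{encode(preprocessors)}"
print(dirname)
```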

get_name()

Get the name of the dataset from the configuration.

load_data()

This ensures that the dataset loading is done properly.

We require the following:

  1. Load the data from the source.

  2. Include observation metadata for cell_type and timepoint.

  3. Drop everything else not required, to speed up processing.

  4. Apply the dataset preprocessors provided.

  5. Return the train and test splits.

Update: to avoid the heavy dependency requirements of psupertime, pypsupertime is treated as an optional dependency rather than a required one. Users are instead expected to run preprocessing ahead of time, which is what this function does: it loads the data, runs it through the preprocessors, and saves the results to their respective output directories.

requires_caching()

Some datasets might require caching because they have preprocessors that take a long time to run (e.g., pseudotime estimation). By default, we assume that datasets do not require caching, but this can be overridden by specific datasets if necessary.

class scTimeBench.shared.dataset.base.BaseDatasetPreprocessor(dataset_dict)

Bases: object

preprocess(ann_data, **kwargs)

Subclasses should implement this method to preprocess and split the dataset according to the metric’s requirements.

requires_caching()

By default, most preprocessors should be simple and not require external packages.

scTimeBench.shared.dataset.base.register_dataset(cls)

Decorator to register a dataset class in the DATASET_REGISTRY.

scTimeBench.shared.dataset.base.register_dataset_preprocessor(cls)

Decorator to register a dataset preprocessor class in the DATASET_PREPROCESSOR_REGISTRY.
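A registry decorator of this kind typically just records the class in a module-level dict and returns it unchanged. A minimal sketch in the style of register_dataset():

```python
# Minimal registry decorator sketch; the real DATASET_REGISTRY lives in
# scTimeBench.shared.dataset.base and may key entries differently.
DATASET_REGISTRY = {}

def register_dataset(cls):
    # Record the class under its name and return it unmodified,
    # so the decorator does not change normal class usage.
    DATASET_REGISTRY[cls.__name__] = cls
    return cls

@register_dataset
class MyDataset:
    pass

print('MyDataset' in DATASET_REGISTRY)  # True
```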

Shared utility functions for loading datasets and output files.

scTimeBench.shared.utils.animate()
scTimeBench.shared.utils.block_interrupts()
scTimeBench.shared.utils.cheeky_message(sig, frame)
scTimeBench.shared.utils.clear_dataset_cache()

Clear the in-memory dataset cache.

scTimeBench.shared.utils.get_dataset(output_path)

Get the dataset from the pickled dataset file in output_path.

Args:

output_path: Path to the method output directory

Returns:

The dataset object loaded from the pickled file

scTimeBench.shared.utils.is_log_normalized_to_counts(ann_data, counts=10000)

Heuristic to determine whether the data is log-normalized to a given counts threshold. First checks whether ann_data.X is raw; if not, checks that the data is log-normalized to counts=10_000.

Args:

ann_data: The AnnData object to check
counts: The expected counts value (default is 10_000)

Returns:

True if the data is log-normalized to the expected counts, False otherwise

scTimeBench.shared.utils.is_raw(ann_data: AnnData)

Returns whether the data is raw (i.e. not log-normalized) by checking that:

  1. All the data is non-negative.

  2. All the data is integer-valued.
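The two checks above can be sketched with NumPy. This is a simplified stand-in; the real is_raw() operates on ann_data.X, which may be sparse:

```python
import numpy as np

# Sketch of the is_raw() heuristic: raw counts are
# non-negative and integer-valued.
def is_raw(x):
    return bool(np.all(x >= 0) and np.all(np.mod(x, 1) == 0))

print(is_raw(np.array([[0., 3., 7.]])))  # True
print(is_raw(np.array([[0.0, 1.4]])))    # False
```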

scTimeBench.shared.utils.load_output_file(output_path, required_output: RequiredOutputFiles)

Load a method output file from output_path.

Args:

output_path: Path to the method output directory
required_output: RequiredOutputFiles enum value specifying which file to load

Returns:

For .npy files: numpy array
For .parquet files: pandas DataFrame

scTimeBench.shared.utils.load_test_dataset(output_path)

Load the test dataset from the pickled dataset file in output_path.

Args:

output_path: Path to the method output directory

Returns:

The test AnnData object from the dataset

scTimeBench.shared.utils.restore_interrupts()

Helper functions for miscellaneous tasks.

scTimeBench.shared.helpers.parse_cell_lineage(file_path, equivalence_file_path=None)

Parse a cell lineage file and create a dictionary mapping each cell type to its descendants.

Parameters:

file_path : str

Path to the lineage file (split by =>)

equivalence_file_path : str, optional

Path to an equivalence file used to canonicalize cell type names

Returns:

dict

Dictionary mapping canonicalized cell types to their descendants
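Given the "=>" separator mentioned above, the core parsing step might look like this. It is a hedged sketch that takes lines directly and omits the equivalence mapping:

```python
# Illustrative parser for lineage lines of the form 'parent => child';
# the real parse_cell_lineage() reads from a file and also applies an
# optional equivalence mapping to canonicalize names.
def parse_cell_lineage(lines):
    lineage = {}
    for line in lines:
        parent, child = (part.strip() for part in line.split('=>'))
        lineage.setdefault(parent, []).append(child)
    return lineage

print(parse_cell_lineage(['progenitor => neuron', 'progenitor => glia']))
# {'progenitor': ['neuron', 'glia']}
```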

scTimeBench.shared.helpers.parse_equivalence(file_path)

Parse a cell equivalence file and create a dictionary mapping equivalent names.

Parameters:

file_path : str

Path to the equivalence file (split by ,)

Returns:

dict

Dictionary mapping alias cell type names to their canonical name.

Method Execution

These modules provide the method runner interface and the helper used by the benchmark to launch methods and collect their outputs.

Note: this file is used by method implementations as a base class. Because it runs outside the src/ folder, it must use scTimeBench.* imports instead of relative imports.

class scTimeBench.method_utils.method_runner.BaseMethod(yaml_config)

Bases: object

generate(test_ann_data)

Main generation method that dispatches to individual output generators. Each output is saved to its own file under self.output_path.

generate_embedding(test_ann_data) → ndarray

Generate embeddings for the current timepoint. Returns: np.ndarray of shape (n_cells, embedding_dim)

generate_next_cell_type(test_ann_data) → DataFrame

Generate next cell type predictions. Returns: pd.DataFrame with cell type predictions

generate_next_tp_embedding(test_ann_data) → ndarray

Generate embeddings for the next timepoint. Returns: np.ndarray of shape (n_cells, embedding_dim)

generate_next_tp_gex(test_ann_data) → ndarray

Generate gene expression for the next timepoint. Returns: np.ndarray of shape (n_cells, n_genes)

generate_pred_graph(test_ann_data) → ndarray

Generate predicted graph. Returns: np.ndarray representing the predicted graph

generate_zero_to_end_pred_gex(first_tp_cells, all_tps) → AnnData

Generate predicted gene expression from the first to the last timepoint. Returns: AnnData object with predicted gene expression across all timepoints

train(ann_data, all_tps=None)
scTimeBench.method_utils.method_runner.get_parser()
scTimeBench.method_utils.method_runner.main(method_class: BaseMethod)
scTimeBench.method_utils.method_runner.process_yaml(yaml_path)
class scTimeBench.method_utils.ot_method_runner.BaseOTMethod(yaml_config)

Bases: BaseMethod

Base class for OT-based methods.

generate_embedding(test_ann_data) → ndarray

Generate PCA embeddings from gene expression data.

generate_next_cell_type(test_ann_data) → DataFrame

Generate next cell type predictions using transport plan.

generate_next_tp_embedding(test_ann_data) → ndarray

Generate embeddings for the next timepoint using transport plan.

generate_next_tp_gex(test_ann_data) → ndarray

Generate gene expression for the next timepoint using transport plan.

get_transport_plan(source_data, target_data)

Given source and target data, compute the transport plan. Subclasses representing OT methods should implement this method.

Parameters:

source_data : np.ndarray

Source data matrix (cells x features)

target_data : np.ndarray

Target data matrix (cells x features)

Returns:

np.ndarray

Transport plan matrix (source cells x target cells)

train(ann_data, all_tps=None)
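The transport plan returned by get_transport_plan() can be applied to push source cells toward the next timepoint. How the generators use the plan is an assumption here; plan-weighted averaging of target embeddings is one common choice:

```python
import numpy as np

# Illustrative use of a transport plan P (source cells x target cells):
# take each source cell's next-timepoint embedding as the plan-weighted
# average of target embeddings. This is an assumption about how
# generate_next_tp_embedding() might use get_transport_plan().
rng = np.random.default_rng(0)
source = rng.normal(size=(5, 3))   # 5 source cells, 3 features
target = rng.normal(size=(4, 3))   # 4 target cells, 3 features

plan = np.full((5, 4), 1 / 4)              # uniform plan as a stand-in
plan = plan / plan.sum(axis=1, keepdims=True)  # row-normalize

next_tp_embedding = plan @ target  # shape (5, 3)
print(next_tp_embedding.shape)
```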

Metric Framework

These modules define the metric base class and the method manager used to bind datasets to method outputs during evaluation.

Base class for all metrics. They should all implement the eval method, and depend on the dataset that they belong to.

class scTimeBench.metrics.base.BaseMetric(config: Config, db_manager: DatabaseManager, metric_config: dict)

Bases: object

final eval()

Evaluation function that handles the calling of submetrics if applicable.

It proceeds as follows:

  1. If there are submetrics defined, we create an instance of each submetric.

  2. We call the _eval function of each submetric.

  3. From this _eval function, we further call the _submetric_eval function that each subclass must implement.
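The dispatch described above can be sketched as follows. The names _eval and _submetric_eval mirror the docstring, but the real class also carries config, database access, and result recording:

```python
# Sketch of the submetric dispatch in BaseMetric.eval().
class BaseMetric:
    SUBMETRICS = ()  # subclasses may list submetric classes here

    def eval(self):
        if self.SUBMETRICS:
            # Steps 1-2: instantiate each submetric and call its _eval.
            return {sub.__name__: sub()._eval() for sub in self.SUBMETRICS}
        return self._eval()

    def _eval(self):
        # Step 3: delegate to the subclass-implemented _submetric_eval.
        return self._submetric_eval()

    def _submetric_eval(self):
        raise NotImplementedError

# Hypothetical metrics to exercise the dispatch.
class Accuracy(BaseMetric):
    def _submetric_eval(self):
        return 0.9

class Composite(BaseMetric):
    SUBMETRICS = (Accuracy,)

print(Composite().eval())  # {'Accuracy': 0.9}
```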

scTimeBench.metrics.base.register_metric(cls)

Decorator to register a metric class in the METRIC_REGISTRY.

scTimeBench.metrics.base.skip_metric(cls)

Decorator to register a skip metric class in the SKIP_METRIC_REGISTRY.

Method manager base class.

class scTimeBench.metrics.method_manager.MethodManager(config, dataset: BaseDataset)

Bases: object

train_and_test(yaml_config_path)

Runs the train and test script provided in the config.

Trajectory Inference

These modules implement the trajectory inference abstractions and concrete inference strategies used by the metrics.

Base trajectory inference model.

This is the base class for all trajectory inference models, i.e. given an ann data and its timepoints, we want to infer the trajectory structure.

Examples are the kNN graph-based methods, or the optimal transport based methods.

class scTimeBench.trajectory_infer.base.BaseTrajectoryInferMethod(traj_config)

Bases: object

encode()

Hash the trajectory inference method based on its class name and parameters.

encode_for_classifier()

Hash the trajectory inference method for the classifier based on its class name and parameters.

This differs from the regular encode in that it ignores the from_tp_zero setting, since the classifier should be shared regardless of that setting.

final infer_trajectory(output_path, per_tp=False)

Infer the trajectory using the kNN graph-based method.

  1. Separate each embedding by time.

  2. Find the k nearest neighbors in the next time point embedding space.

  3. Consolidate the cell types per time point based on the kNN results.
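Step 2 above can be sketched with a brute-force nearest-neighbor search; the package's kNN implementation may use an optimized library instead:

```python
import numpy as np

# Sketch of step 2: for each cell at time t, find its k nearest
# neighbors among the cells at time t+1 (brute-force distances).
def knn_next_tp(emb_t, emb_t1, k=3):
    # Pairwise Euclidean distances: (n_cells_t, n_cells_t1)
    d = np.linalg.norm(emb_t[:, None, :] - emb_t1[None, :, :], axis=-1)
    # Indices of the k closest next-timepoint cells per source cell
    return np.argsort(d, axis=1)[:, :k]

rng = np.random.default_rng(0)
idx = knn_next_tp(rng.normal(size=(6, 2)), rng.normal(size=(8, 2)))
print(idx.shape)  # (6, 3)
```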

predict_next_tp(output_path, test_ann_data=None, traj_infer_path=None)

Predict the next timepoint cell types using the trajectory inference model.

supports_gex()

Function to be overwritten if the trajectory inference method supports gene expression data. By default, we assume it does not.

final train_and_predict(output_path, train_only=False)

Trains and predicts using the trajectory inference model.

train_and_predict_k_fold_cv(output_path, k)

Does the train and predict with k-fold cross validation.

We store everything under traj_infer_path/k_fold_<k>/fold_<i>/

uses_gene_expr()
class scTimeBench.trajectory_infer.base.TrajectoryInferenceMethodFactory

Bases: object

get_trajectory_infer_method(traj_config) → BaseTrajectoryInferMethod
scTimeBench.trajectory_infer.base.register_trajectory_inference_method(cls)

Decorator to register a trajectory inference method.

Classifier implementation for trajectory inference.

class scTimeBench.trajectory_infer.classifier.CellTypist(traj_config)

Bases: BaseTrajectoryInferMethod

class scTimeBench.trajectory_infer.classifier.Classifier(traj_config)

Bases: BaseTrajectoryInferMethod

class scTimeBench.trajectory_infer.classifier.ClassifierTypes(value)

Bases: Enum

Supported classifier types.

BOOSTING = 'boosting'
RANDOM_FOREST = 'random_forest'

kNN implementation for trajectory inference.

class scTimeBench.trajectory_infer.kNN.kNN(traj_config)

Bases: BaseTrajectoryInferMethod

get_kNN_graph(output_path)

Function to get the kNN graph used in the trajectory inference.

This can be useful for visualization or further analysis.

class scTimeBench.trajectory_infer.kNN.kNNStrategy(value)

Bases: Enum

Strategies for consolidating kNN neighbor labels.

MAJORITY_VOTE = 'majority_vote'
WEIGHTED_AVERAGE = 'weighted_average'
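The MAJORITY_VOTE strategy can be sketched in a few lines: each cell's predicted next cell type is the most common label among its kNN neighbors. A simplified stand-in for the package's consolidation step:

```python
from collections import Counter

# Sketch of the MAJORITY_VOTE strategy: take the most common
# neighbor label per cell.
def majority_vote(neighbor_labels):
    return [Counter(labels).most_common(1)[0][0] for labels in neighbor_labels]

print(majority_vote([['A', 'A', 'B'], ['B', 'B', 'B']]))  # ['A', 'B']
```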

OT implementation for trajectory inference.

class scTimeBench.trajectory_infer.ot.OptimalTransport(traj_config)

Bases: BaseTrajectoryInferMethod

WARNING: This is untested and deprecated.

Please switch to either kNN or Classifier with scikit-learn based classifiers for better performance and maintainability.

cell_types_to_one_hot(cell_types)

Given a list of cell types, convert to one-hot encoding

get_ot_labels(true_embed, pred_embed, one_hot_labels)

Given the true embeddings, predicted embeddings and one-hot encoding of true cell types, get the transport plan using optimal transport

soft_labels_to_cell_types(labels, index_to_type)

Given the labels from get_ot_labels, and the index to type mapping, convert the soft labels to hard cell type labels
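The one-hot / soft-label round trip described by these helpers can be sketched as follows. The signatures are illustrative; the real methods consume the outputs of the OT solver:

```python
import numpy as np

# Sketch of cell_types_to_one_hot() and soft_labels_to_cell_types().
def cell_types_to_one_hot(cell_types):
    types = sorted(set(cell_types))
    index = {t: i for i, t in enumerate(types)}
    one_hot = np.zeros((len(cell_types), len(types)))
    for row, t in enumerate(cell_types):
        one_hot[row, index[t]] = 1.0
    return one_hot, {i: t for t, i in index.items()}

def soft_labels_to_cell_types(labels, index_to_type):
    # Hard assignment: pick the highest-weight type per cell.
    return [index_to_type[i] for i in labels.argmax(axis=1)]

one_hot, index_to_type = cell_types_to_one_hot(['A', 'B', 'A'])
print(soft_labels_to_cell_types(one_hot, index_to_type))  # ['A', 'B', 'A']
```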

supports_gex()

By default, OT does not support gene expression data, as it is primarily designed for embedding-based trajectory inference: OT can be computationally intensive and may not scale well to high-dimensional gene expression data, leading to long runtimes and potential memory issues.