# Shared fixtures for the spinner tests in this file.
#
# `g` is a random G(n, p) graph (20 vertices, edge probability 0.3) carrying
# one categorical and one numeric feature on both its vertices and its edges.
# NOTE(review): no seed is set, so `g` differs between test runs.
g <- sample_gnp(20, 0.3)
g <- set_vertex_attr(
  g,
  name = "node_cat_feat",
  value = sample(c("green", "red"), vcount(g), replace = TRUE)
)
g <- set_edge_attr(
  g,
  name = "edge_cat_feat",
  value = sample(c("black", "white", "gray"), ecount(g), replace = TRUE)
)
g <- set_vertex_attr(g, name = "node_regr_feat", value = rnorm(vcount(g), 10, 3))
g <- set_edge_attr(g, name = "edge_regr_feat", value = rnorm(ecount(g), 5, 1))

# `new_g` extends `g` with 3 extra vertices (ids 21-23) and 4 extra edges
# (1-21, 10-22, 20-23, 15-22); the tests feed it to `pred_fun` to predict on
# elements the model was not trained on.
new_g <- add_vertices(g, 3)
new_g <- add_edges(new_g, c(1, 21, 10, 22, 20, 23, 15, 22))

# Node-level spinner model using both node and edge features: checks the
# structure of the returned list and the shape of predictions on `new_g`.
test_that("Correct outcome format and size for base outcome1", {
  skip_if_not_installed("torch")
  # If calling cuda_is_available() throws (presumably because the torch
  # backend is not usable here), skip instead of failing.
  skip_if(is.na(tryCatch(cuda_is_available(), error = function(e) NA)))

  outcome1 <- spinner(
    g, "node",
    c("node_cat_feat", "node_regr_feat"),
    c("edge_cat_feat", "edge_regr_feat")
  )

  expect_identical(class(outcome1), "list")
  expect_length(outcome1, 8)
  expect_named(
    outcome1,
    c("graph", "model_description", "model_summary", "pred_fun",
      "cv_errors", "summary_errors", "history", "time_log")
  )
  expect_true(is_igraph(outcome1$graph))
  expect_true(is.function(outcome1$pred_fun))
  expect_true(is.character(outcome1$model_description))
  expect_true(is.list(outcome1$model_summary))
  expect_equal(dim(outcome1$cv_errors), c(3, 4))
  expect_length(outcome1$summary_errors, 3)
  expect_true(is.ggplot(outcome1$history))
  expect_s4_class(outcome1$time_log, "Period")
  # `new_g` has 3 vertices beyond `g`; the first prediction element is
  # expected to have one row per new vertex.
  expect_equal(dim(outcome1$pred_fun(new_g)[[1]]), c(3, 3))
})


# Edge-level spinner model using both node and edge features: checks the
# structure of the returned list and the shape of predictions on `new_g`.
test_that("Correct outcome format and size for base outcome2", {
  skip_if_not_installed("torch")
  # If calling cuda_is_available() throws (presumably because the torch
  # backend is not usable here), skip instead of failing.
  skip_if(is.na(tryCatch(cuda_is_available(), error = function(e) NA)))

  outcome2 <- spinner(
    g, "edge",
    c("node_cat_feat", "node_regr_feat"),
    c("edge_cat_feat", "edge_regr_feat")
  )

  expect_identical(class(outcome2), "list")
  expect_length(outcome2, 8)
  expect_named(
    outcome2,
    c("graph", "model_description", "model_summary", "pred_fun",
      "cv_errors", "summary_errors", "history", "time_log")
  )
  expect_true(is_igraph(outcome2$graph))
  expect_true(is.function(outcome2$pred_fun))
  expect_true(is.character(outcome2$model_description))
  expect_true(is.list(outcome2$model_summary))
  expect_equal(dim(outcome2$cv_errors), c(3, 4))
  expect_length(outcome2$summary_errors, 3)
  expect_true(is.ggplot(outcome2$history))
  expect_s4_class(outcome2$time_log, "Period")
  # `new_g` has 4 edges beyond `g`; the first prediction element is expected
  # to have one row per new edge.
  expect_equal(dim(outcome2$pred_fun(new_g)[[1]]), c(4, 5))
})

# Node-level spinner model with no explicit feature lists and a custom
# `node_embedding_size`: checks the returned structure and prediction shape.
test_that("Correct outcome format and size for base outcome2bis", {
  skip_if_not_installed("torch")
  # If calling cuda_is_available() throws (presumably because the torch
  # backend is not usable here), skip instead of failing.
  skip_if(is.na(tryCatch(cuda_is_available(), error = function(e) NA)))

  outcome2bis <- spinner(g, "node", node_embedding_size = 10)

  expect_identical(class(outcome2bis), "list")
  expect_length(outcome2bis, 8)
  expect_named(
    outcome2bis,
    c("graph", "model_description", "model_summary", "pred_fun",
      "cv_errors", "summary_errors", "history", "time_log")
  )
  expect_true(is_igraph(outcome2bis$graph))
  expect_true(is.function(outcome2bis$pred_fun))
  expect_true(is.character(outcome2bis$model_description))
  expect_true(is.list(outcome2bis$model_summary))
  expect_equal(dim(outcome2bis$cv_errors), c(3, 4))
  expect_length(outcome2bis$summary_errors, 3)
  expect_true(is.ggplot(outcome2bis$history))
  expect_s4_class(outcome2bis$time_log, "Period")
  # One row per new vertex of `new_g` (3 of them).
  expect_equal(dim(outcome2bis$pred_fun(new_g)[[1]]), c(3, 11))
})

# Random search over 3 edge-level configurations with `keep = TRUE`: the
# result must also carry every fitted model in `all_models`.
test_that("Correct outcome format and size for base outcome3", {
  skip_if_not_installed("torch")
  # If calling cuda_is_available() throws (presumably because the torch
  # backend is not usable here), skip instead of failing.
  skip_if(is.na(tryCatch(cuda_is_available(), error = function(e) NA)))

  outcome3 <- spinner_random_search(
    3, g, "edge",
    c("node_cat_feat", "node_regr_feat"),
    c("edge_cat_feat", "edge_regr_feat"),
    keep = TRUE
  )

  expect_identical(class(outcome3), "list")
  expect_length(outcome3, 4)
  expect_named(outcome3, c("random_search", "best", "time_log", "all_models"))
  expect_true(is.data.frame(outcome3$random_search))
  expect_equal(dim(outcome3$random_search), c(3, 17))
  expect_true(is.list(outcome3$best))
  expect_length(outcome3$best, 8)
  expect_equal(dim(outcome3$best$cv_errors), c(2, 4))
  expect_length(outcome3$all_models, 3)
  expect_s4_class(outcome3$time_log, "Period")
})

# Random search over 3 edge-level configurations with default features and
# `keep = FALSE`: the result must NOT carry an `all_models` element.
test_that("Correct outcome format and size for base outcome4", {
  skip_if_not_installed("torch")
  # If calling cuda_is_available() throws (presumably because the torch
  # backend is not usable here), skip instead of failing.
  skip_if(is.na(tryCatch(cuda_is_available(), error = function(e) NA)))

  outcome4 <- spinner_random_search(3, g, "edge", keep = FALSE)

  expect_identical(class(outcome4), "list")
  expect_length(outcome4, 3)
  expect_named(outcome4, c("random_search", "best", "time_log"))
  expect_true(is.data.frame(outcome4$random_search))
  expect_equal(dim(outcome4$random_search), c(3, 19))
  expect_true(is.list(outcome4$best))
  expect_length(outcome4$best, 8)
  expect_equal(dim(outcome4$best$cv_errors), c(2, 4))
  # `[[` matches names exactly, unlike `$`, which would partially match.
  expect_null(outcome4[["all_models"]])
  expect_s4_class(outcome4$time_log, "Period")
})
