[Date Prev][Date Next][Thread Prev][Thread Next][Date Index][Thread Index]

[win-pv-devel] [PATCH 3/4] Check 'Reboot' value in the 'Request' key



If the 'Reboot' value is set to a service name, pop up a message in the
active session saying that the specified service requires a system
reboot to complete installation. If the session user responds
affirmatively to the message, initiate a reboot.

Signed-off-by: Paul Durrant <paul.durrant@xxxxxxxxxx>
---
 src/monitor/monitor.c | 413 ++++++++++++++++++++++++++++++++++++++++++++++++--
 1 file changed, 403 insertions(+), 10 deletions(-)

diff --git a/src/monitor/monitor.c b/src/monitor/monitor.c
index 6c66518..9e4555f 100644
--- a/src/monitor/monitor.c
+++ b/src/monitor/monitor.c
@@ -33,6 +33,9 @@
 #include <tchar.h>
 #include <stdlib.h>
 #include <strsafe.h>
+#include <wtsapi32.h>
+#include <malloc.h>
+#include <assert.h>
 
 #include <version.h>
 
@@ -48,6 +51,7 @@ typedef struct _MONITOR_CONTEXT {
     HANDLE                  StopEvent;
     HANDLE                  RequestEvent;
     HKEY                    RequestKey;
+    BOOL                    RebootPending;
 } MONITOR_CONTEXT, *PMONITOR_CONTEXT;
 
 MONITOR_CONTEXT MonitorContext;
@@ -256,6 +260,318 @@ MonitorCtrlHandlerEx(
     return ERROR_CALL_NOT_IMPLEMENTED;
 }
 
+//
+// Translate a WTS_CONNECTSTATE_CLASS value into a short printable name
+// for log output. Any value not listed below yields "UNKNOWN".
+//
+static const CHAR *
+WTSStateName(
+    IN  DWORD   State
+    )
+{
+    const CHAR  *Name = "UNKNOWN";
+
+    switch (State) {
+    case WTSActive:
+        Name = "Active";
+        break;
+
+    case WTSConnected:
+        Name = "Connected";
+        break;
+
+    case WTSConnectQuery:
+        Name = "ConnectQuery";
+        break;
+
+    case WTSShadow:
+        Name = "Shadow";
+        break;
+
+    case WTSDisconnected:
+        Name = "Disconnected";
+        break;
+
+    case WTSIdle:
+        Name = "Idle";
+        break;
+
+    case WTSListen:
+        Name = "Listen";
+        break;
+
+    case WTSReset:
+        Name = "Reset";
+        break;
+
+    case WTSDown:
+        Name = "Down";
+        break;
+
+    case WTSInit:
+        Name = "Init";
+        break;
+
+    default:
+        break;
+    }
+
+    return Name;
+}
+
+//
+// Restart the local machine immediately: no timeout and no shutdown
+// message, applications are forced closed and the system reboots.
+// The reason is recorded as a planned OS installation shutdown.
+// The calling process must have SE_SHUTDOWN_NAME enabled. A failure
+// is logged (it would otherwise be invisible) but not returned.
+//
+static VOID
+DoReboot(
+    VOID
+    )
+{
+    BOOL    Success;
+    HRESULT Error;
+
+    Log("====>");
+
+    Success = InitiateSystemShutdownEx(NULL,
+                                       NULL,
+                                       0,
+                                       TRUE,
+                                       TRUE,
+                                       SHTDN_REASON_MAJOR_OPERATINGSYSTEM |
+                                       SHTDN_REASON_MINOR_INSTALLATION |
+                                       SHTDN_REASON_FLAG_PLANNED);
+    if (!Success)
+        goto fail1;
+
+    Log("<====");
+
+    return;
+
+fail1:
+    Error = GetLastError();
+
+    {
+        PTCHAR  Message;
+        Message = GetErrorMessage(Error);
+        Log("fail1 (%s)", Message);
+        LocalFree(Message);
+    }
+}
+
+//
+// Ask the interactive user to restart the system on behalf of the
+// service DriverName (a key name under SERVICES_KEY).
+//
+// The service's REG_SZ DisplayName is read and used to build a Yes/No
+// message box. The box is sent to the first session found in the
+// WTSActive state, and the call waits for the user's answer. Once the
+// message has been delivered, Context->RebootPending is set. If the
+// answer is 'Yes', DoReboot() is called. If no session is active,
+// nothing is shown and RebootPending is left unchanged.
+//
+// Errors are logged; nothing is returned to the caller.
+//
+static VOID
+PromptForReboot(
+    IN PTCHAR           DriverName
+    )
+{
+    PMONITOR_CONTEXT    Context = &MonitorContext;
+    HRESULT             Result;
+    TCHAR               ServiceKeyName[MAX_PATH];
+    HKEY                ServiceKey;
+    DWORD               MaxValueLength;
+    DWORD               DisplayNameLength;
+    PTCHAR              DisplayName;
+    DWORD               Type;
+    TCHAR               Title[] = TEXT(VENDOR_NAME_STR);
+    TCHAR               Message[MAXIMUM_BUFFER_SIZE];
+    DWORD               MessageLength;
+    PWTS_SESSION_INFO   SessionInfo;
+    DWORD               Count;
+    DWORD               Index;
+    BOOL                Success;
+    HRESULT             Error;
+
+    Log("====> (%s)", DriverName);
+
+    //
+    // The StringCb* functions take the destination size in bytes, so
+    // pass sizeof the buffer rather than its element count.
+    //
+    Result = StringCbPrintf(ServiceKeyName,
+                            sizeof (ServiceKeyName),
+                            SERVICES_KEY "\\%s",
+                            DriverName);
+    assert(SUCCEEDED(Result));
+
+    Error = RegOpenKeyEx(HKEY_LOCAL_MACHINE,
+                         ServiceKeyName,
+                         0,
+                         KEY_READ,
+                         &ServiceKey);
+    if (Error != ERROR_SUCCESS) {
+        SetLastError(Error);
+        goto fail1;
+    }
+
+    Error = RegQueryInfoKey(ServiceKey,
+                            NULL,
+                            NULL,
+                            NULL,
+                            NULL,
+                            NULL,
+                            NULL,
+                            NULL,
+                            NULL,
+                            &MaxValueLength,
+                            NULL,
+                            NULL);
+    if (Error != ERROR_SUCCESS) {
+        SetLastError(Error);
+        goto fail2;
+    }
+
+    // Leave room for a terminator: registry data need not be NUL-terminated
+    DisplayNameLength = MaxValueLength + sizeof (TCHAR);
+
+    DisplayName = calloc(1, DisplayNameLength);
+    if (DisplayName == NULL) {
+        // calloc() reports failure via errno, not the Win32 last error
+        SetLastError(ERROR_OUTOFMEMORY);
+        goto fail3;
+    }
+
+    Error = RegQueryValueEx(ServiceKey,
+                            "DisplayName",
+                            NULL,
+                            &Type,
+                            (LPBYTE)DisplayName,
+                            &DisplayNameLength);
+    if (Error != ERROR_SUCCESS) {
+        SetLastError(Error);
+        goto fail4;
+    }
+
+    if (Type != REG_SZ) {
+        SetLastError(ERROR_BAD_FORMAT);
+        goto fail5;
+    }
+
+    Result = StringCbPrintf(Message,
+                            sizeof (Message),
+                            TEXT("%s needs to restart the system to "
+                                 "complete installation.\n"
+                                 "Press 'Yes' to restart the system "
+                                 "now or 'No' if you plan to restart "
+                                 "the system later.\n"),
+                            DisplayName);
+    assert(SUCCEEDED(Result));
+
+    //
+    // WTSSendMessage() takes the text length in bytes. Pass the length
+    // of the formatted string plus its terminator (matching the
+    // sizeof (Title) used for the title). Do not pass the size of the
+    // whole buffer, whose tail is uninitialized.
+    //
+    MessageLength = (DWORD)((_tcslen(Message) + 1) * sizeof (TCHAR));
+
+    Success = WTSEnumerateSessions(WTS_CURRENT_SERVER_HANDLE,
+                                   0,
+                                   1,
+                                   &SessionInfo,
+                                   &Count);
+
+    if (!Success)
+        goto fail6;
+
+    for (Index = 0; Index < Count; Index++) {
+        DWORD                   SessionId = SessionInfo[Index].SessionId;
+        PTCHAR                  Name = SessionInfo[Index].pWinStationName;
+        WTS_CONNECTSTATE_CLASS  State = SessionInfo[Index].State;
+        DWORD                   Response;
+
+        Log("[%u]: %s [%s]",
+            SessionId,
+            Name,
+            WTSStateName(State));
+
+        if (State != WTSActive)
+            continue;
+
+        // Timeout 0 with bWait TRUE: block until the user answers
+        Success = WTSSendMessage(WTS_CURRENT_SERVER_HANDLE,
+                                 SessionId,
+                                 Title,
+                                 sizeof (Title),
+                                 Message,
+                                 MessageLength,
+                                 MB_YESNO | MB_ICONEXCLAMATION,
+                                 0,
+                                 &Response,
+                                 TRUE);
+
+        if (!Success)
+            goto fail7;
+
+        Context->RebootPending = TRUE;
+
+        if (Response == IDYES)
+            DoReboot();
+
+        break;
+    }
+
+    WTSFreeMemory(SessionInfo);
+
+    free(DisplayName);
+
+    RegCloseKey(ServiceKey);
+
+    Log("<====");
+
+    return;
+
+fail7:
+    Log("fail7");
+
+    WTSFreeMemory(SessionInfo);
+
+fail6:
+    Log("fail6");
+
+fail5:
+    Log("fail5");
+
+fail4:
+    Log("fail4");
+
+    free(DisplayName);
+
+fail3:
+    Log("fail3");
+
+fail2:
+    Log("fail2");
+
+    RegCloseKey(ServiceKey);
+
+fail1:
+    Error = GetLastError();
+
+    {
+        PTCHAR  Message;
+        Message = GetErrorMessage(Error);
+        Log("fail1 (%s)", Message);
+        LocalFree(Message);
+    }
+}
+
+//
+// Handle a reboot request left in the 'Reboot' value of the request key.
+//
+// If the value is absent, there is nothing to do. If it is a REG_SZ, it
+// names the service that needs the reboot. The user is prompted via
+// PromptForReboot() unless a prompt has already been delivered
+// (Context->RebootPending), and the value is then deleted. A value of
+// any other type, or any other registry error, is logged and the value
+// is left in place.
+//
+static VOID
+CheckRebootValue(
+    VOID
+    )
+{
+    PMONITOR_CONTEXT    Context = &MonitorContext;
+    HRESULT             Error;
+    DWORD               MaxValueLength;
+    DWORD               RebootLength;
+    PTCHAR              Reboot;
+    DWORD               Type;
+
+    Log("====>");
+
+    Error = RegQueryInfoKey(Context->RequestKey,
+                            NULL,
+                            NULL,
+                            NULL,
+                            NULL,
+                            NULL,
+                            NULL,
+                            NULL,
+                            NULL,
+                            &MaxValueLength,
+                            NULL,
+                            NULL);
+    if (Error != ERROR_SUCCESS) {
+        SetLastError(Error);
+        goto fail1;
+    }
+
+    // Leave room for a terminator: registry data need not be NUL-terminated
+    RebootLength = MaxValueLength + sizeof (TCHAR);
+
+    Reboot = calloc(1, RebootLength);
+    if (Reboot == NULL) {
+        // calloc() reports failure via errno, not the Win32 last error
+        SetLastError(ERROR_OUTOFMEMORY);
+        goto fail2;
+    }
+
+    Error = RegQueryValueEx(Context->RequestKey,
+                            "Reboot",
+                            NULL,
+                            &Type,
+                            (LPBYTE)Reboot,
+                            &RebootLength);
+    if (Error != ERROR_SUCCESS) {
+        // No 'Reboot' value means no reboot has been requested
+        if (Error == ERROR_FILE_NOT_FOUND)
+            goto done;
+
+        SetLastError(Error);
+        goto fail3;
+    }
+
+    if (Type != REG_SZ) {
+        SetLastError(ERROR_BAD_FORMAT);
+        goto fail4;
+    }
+
+    if (!Context->RebootPending)
+        PromptForReboot(Reboot);
+
+    (VOID) RegDeleteValue(Context->RequestKey, "Reboot");
+
+done:
+    free(Reboot);
+
+    Log("<====");
+
+    return;
+
+fail4:
+    Log("fail4");
+
+fail3:
+    Log("fail3");
+
+    free(Reboot);
+
+fail2:
+    Log("fail2");
+
+fail1:
+    Error = GetLastError();
+
+    {
+        PTCHAR  Message;
+        Message = GetErrorMessage(Error);
+        Log("fail1 (%s)", Message);
+        LocalFree(Message);
+    }
+}
+
 static VOID
 CheckRequestKey(
     VOID
@@ -266,6 +582,8 @@ CheckRequestKey(
 
     Log("====>");
 
+    CheckRebootValue();
+
     Error = RegNotifyChangeKeyValue(Context->RequestKey,
                                     TRUE,
                                     REG_NOTIFY_CHANGE_LAST_SET,
@@ -290,6 +608,73 @@ fail1:
     }
 }
 
+//
+// Enable SE_SHUTDOWN_NAME in this process's token so that a later
+// InitiateSystemShutdownEx() call is permitted. Returns TRUE only if
+// the privilege was actually enabled. On failure the error is logged
+// and FALSE is returned.
+//
+static BOOL
+AcquireShutdownPrivilege(
+    VOID
+    )
+{
+    HANDLE              Token;
+    TOKEN_PRIVILEGES    New;
+    BOOL                Success;
+    HRESULT             Error;
+
+    Log("====>");
+
+    New.PrivilegeCount = 1;
+
+    Success = LookupPrivilegeValue(NULL,
+                                   SE_SHUTDOWN_NAME,
+                                   &New.Privileges[0].Luid);
+
+    if (!Success)
+        goto fail1;
+
+    New.Privileges[0].Attributes = SE_PRIVILEGE_ENABLED;
+
+    Success = OpenProcessToken(GetCurrentProcess(),
+                               TOKEN_ADJUST_PRIVILEGES | TOKEN_QUERY,
+                               &Token);
+
+    if (!Success)
+        goto fail2;
+
+    Success = AdjustTokenPrivileges(Token,
+                                    FALSE,
+                                    &New,
+                                    0,
+                                    NULL,
+                                    NULL);
+
+    if (!Success)
+        goto fail3;
+
+    //
+    // AdjustTokenPrivileges() also returns success when the token does
+    // not hold the requested privilege. In that case the last error is
+    // ERROR_NOT_ALL_ASSIGNED rather than ERROR_SUCCESS, and the
+    // privilege has not been enabled.
+    //
+    if (GetLastError() != ERROR_SUCCESS)
+        goto fail4;
+
+    CloseHandle(Token);
+
+    Log("<====");
+
+    return TRUE;
+
+fail4:
+    Log("fail4");
+
+fail3:
+    Log("fail3");
+
+    CloseHandle(Token);
+
+fail2:
+    Log("fail2");
+
+fail1:
+    Error = GetLastError();
+
+    {
+        PTCHAR  Message;
+        Message = GetErrorMessage(Error);
+        Log("fail1 (%s)", Message);
+        LocalFree(Message);
+    }
+
+    return FALSE;
+}
+
 VOID WINAPI
 MonitorMain(
     _In_    DWORD       argc,
@@ -305,16 +690,21 @@ MonitorMain(
 
     Log("====>");
 
+    Success = AcquireShutdownPrivilege();
+
+    if (!Success)
+        goto fail1;
+
     Context->Service = RegisterServiceCtrlHandlerEx(MONITOR_NAME,
                                                     MonitorCtrlHandlerEx,
                                                     NULL);
     if (Context->Service == NULL)
-        goto fail1;
+        goto fail2;
 
     Context->EventLog = RegisterEventSource(NULL,
                                             MONITOR_NAME);
     if (Context->EventLog == NULL)
-        goto fail2;
+        goto fail3;
 
     Context->Status.dwServiceType = SERVICE_WIN32_OWN_PROCESS;
     Context->Status.dwServiceSpecificExitCode = 0;
@@ -327,7 +717,7 @@ MonitorMain(
                                      NULL);
 
     if (Context->StopEvent == NULL)
-        goto fail3;
+        goto fail4;
 
     Context->RequestEvent = CreateEvent(NULL,
                                         TRUE,
@@ -335,7 +725,7 @@ MonitorMain(
                                         NULL);
 
     if (Context->RequestEvent == NULL)
-        goto fail4;
+        goto fail5;
 
     Error = RegOpenKeyEx(HKEY_LOCAL_MACHINE,
                          REQUEST_KEY,
@@ -344,7 +734,7 @@ MonitorMain(
                          &Context->RequestKey);
 
     if (Error != ERROR_SUCCESS)
-        goto fail5;
+        goto fail6;
 
     SetEvent(Context->RequestEvent);
 
@@ -392,23 +782,26 @@ done:
 
     return;
 
-fail5:
-    Log("fail5");
+fail6:
+    Log("fail6");
 
     ReportStatus(SERVICE_STOPPED, GetLastError(), 0);
 
     CloseHandle(Context->RequestEvent);
 
+fail5:
+    Log("fail5");
+
+    CloseHandle(Context->StopEvent);
+
 fail4:
     Log("fail4");
 
-    CloseHandle(Context->StopEvent);
+    (VOID) DeregisterEventSource(Context->EventLog);
 
 fail3:
     Log("fail3");
 
-    (VOID) DeregisterEventSource(Context->EventLog);
-
 fail2:
     Log("fail2");
 
-- 
2.1.1


_______________________________________________
win-pv-devel mailing list
win-pv-devel@xxxxxxxxxxxxxxxxxxxx
https://lists.xenproject.org/cgi-bin/mailman/listinfo/win-pv-devel

 


Rackspace

Lists.xenproject.org is hosted with RackSpace, monitoring our
servers 24x7x365 and backed by RackSpace's Fanatical Support®.